浅谈 FFT

前言

这篇文章参考了很多地方的文章,包括 OI WIKI、洛谷题解等。我对这些东西进行了整理之后才有的这篇文章。顺便纠正了一些错误。

Made by: 2x6_81。感谢 henrytb 指出的一些错误。

概述

离散傅里叶变换(Discrete Fourier Transform,缩写为 DFT),是傅里叶变换在时域和频域上都呈离散的形式,将信号的时域采样变换为其 DTFT 的频域采样。

FFT 是一种高效实现 DFT 的算法,称为快速傅立叶变换(Fast Fourier Transform,FFT)。它对傅里叶变换的理论并没有新的发现,但是对于在计算机系统或者说数字系统中应用离散傅立叶变换,可以说是进了一大步。快速数论变换 (NTT) 是快速傅里叶变换(FFT)在数论基础上的实现。

在 1965 年,Cooley 和 Tukey 发表了快速傅里叶变换算法。事实上 FFT 早在这之前就被发现过了,但是在当时现代计算机并未问世,人们没有意识到 FFT 的重要性。一些调查者认为 FFT 是由 Runge 和 König 在 1924 年发现的。但事实上高斯早在 1805 年就发明了这个算法,但一直没有发表。——OI WIKI

注意:模块并不属于中国计算机学会划定的提高组知识点考察范围。

前置知识

多项式的度

对于一个多项式 f(x)f(x),称其最高次项的次数为该多项式的 度(Degree),记作 degf\operatorname{deg} f

任意角与弧度制

我们在初中学习过角度值,但是角度不是一个数,这给我们深入研究带来了一定的困难,还有其他的问题无法解释清,所以我们换用弧度制描述角。

首先我们用旋转的思路定义角,角可以看成平面内一条射线绕其端点从一个位置旋转到另一个位置形成的图形。开始的位置称为始边,结束的位置称为终边。

我们规定,按 逆时针 方向旋转形成的角叫做 正角,按 顺时针 方向旋转所形成的角叫做 负角,如果这条射线没有做任何旋转,称为 零角。这样我们就把角的概念推向了 任意角

然后我们介绍 弧度制,把长度等于半径长的弧所对的圆心角称为 1 弧度的角,用符号 rad\text{rad} 表示,读作:弧度。

一般地,正角的弧度数为正,负角的弧度数为负,零角的弧度数为 0,如果半径为 rr 的圆的圆心角 α\alpha 所对弧长为 ll,则 α=lrα=lr\vert \alpha \vert =\frac{l}{r} \vert \alpha \vert = \frac{l}{r}。利用这个公式还可以写出弧长和扇形面积公式,在此略过。

那么,我们发现 360360^\circ 的角弧度数为 2π2 \pi,这样有了对应关系之后,我们可以进行角度值和弧度制的转化了。

我们考虑一个角,将其终边再旋转一周,甚至多周,始边位置不动,那么终边位置永远是相同的,我们称这些角为 终边位置相同的角

与角 α\alpha 终边位置相同的角的集合很容易得出,为 {θθ=α+2kπ,kZ}\{ \theta \vert \theta = \alpha + 2k\pi, k \in \mathbb{Z} \}

可以理解为:给这个角的边不停加转一圈,终边位置不变。

任意角三角函数定义

在平面直角坐标系 xOyxOy 中设 β\angle \beta 的始边为 xx 轴的正半轴,设点 P(x,y)P(x, y)β\angle \beta 的终边上不与原点 OO 重合的任意一点,设 r=OPr=OP,令β=α\angle \beta = \angle \alpha,则有 sinα=yr\sin \alpha = \frac{y}{r}cosα=xr\cos \alpha = \frac{x}{r}tanα=yx\tan \alpha = \frac{y}{x} 。特别地,在 r=1r = 1 时,有 P(cosα,sinα)P(\cos \alpha, \sin \alpha)

复数

复数的基础知识

由于 1-1 不能开方,所以引入了一个数:虚数 i\text{i} ,满足 i2=1\text{i}^2 = -1

我们定义形如 a+bia + b \text{i} ,其中 a,bRa,b\in \mathbb{R} 的数称为 复数。其中 i\text{i} 被称为 虚数单位,全体复数的集合叫做 复数集

复数通常用 zz 表示,即 z=a+biz = a + b \text{i}。这种形式被称为 复数的代数形式。其中 aa 称为复数 zz实部,表示为 Re(z)\operatorname{Re}(z)bb 称为复数 zz虚部 ,表示为 Im(z)\operatorname{Im}(z)

我们可以用一个数对 (a,b)(a, b) 来表示一个复数 z=a+biz = a + b \text{i} 。可以发现 复数集与平面直角坐标系中的点集一一对应。于是,我们找到了复数的一种几何意义。

那么这个平面直角坐标系就不再一般,因为平面直角坐标系中的点具有了特殊意义——表示一个复数,所以我们把这样的平面直角坐标系称为 复平面xx 轴称为 实轴yy 轴称为 虚轴。我们进一步地说:复数集与复平面内所有的点所构成的集合是一一对应的

我们考虑到学过的平面向量的知识,发现向量的坐标表示也是一个有序实数对 (a,b)(a,b),显然,复数 z=a+biz=a+b\text{i} 对应复平面内的点 Z(a,b)Z(a,b),那么它还对应平面向量 OZ=(a,b)\overrightarrow{OZ}=(a,b),于是我们又找到了复数的另一种几何意义:复数集与复平面内的向量所构成的集合是一一对应的(实数 00 与零向量对应)

于是,我们由向量的知识迁移到复数上来,定义 复数的模 就是复数所对应的向量的模。复数 z=a+biz=a+b\text{i} 的模 z=a2+b2\vert z \vert = \sqrt{a^2+b^2}

于是为了方便,我们常把复数 z=a+biz=a+b\text{i} 称为点 ZZ 或向量 OZ\overrightarrow{OZ},并规定相等的向量表示同一个复数。

并且由向量的知识我们发现,虚数不可以比较大小(但是实数是可以的)。

复数的三角表示

对于一个复数 z=a+biz=a+b\text{i} ,我们让 r=a2+b2r = \sqrt{a^2 + b^2} ,那么复数 zz 可以表示为 r(cosθ+isinθ)r \cdot ( \cos \theta + \text{i} \sin \theta) ,其中 cosθ=ar\cos \theta = \frac{a}{r}sinθ=br\sin \theta = \frac{b}{r}

于是,复数有如下性质(以下 z1=r1(cosθ1+isinθ1)z_1 = r_1 \cdot (\cos \theta_1 + \text{i} \sin \theta_1)z2=r2(cosθ2+isinθ2)z_2 = r_2 \cdot (\cos \theta_2 + \text{i} \sin \theta_2)):

  1. z1z2=r1r2(cos(θ1+θ2)+isin(θ1+θ2))z_1 \cdot z_2 = r_1 \cdot r_2 \cdot (\cos (\theta_1 + \theta_2) + \text{i} \sin (\theta_1 + \theta_2))
  2. z1x=z1=r1x(cos(xθ1)+isin(xθ1))z_1^x = z_1 = r_1^x \cdot (\cos (x \cdot \theta_1) + \text{i} \sin (x \cdot \theta_1))
  3. 根据欧拉公式,有 z1=eiπθ1z_1 = e^{i\pi\theta_1}
  4. 1z1=z11=eiπθ1=r1(cos(θ1)+isin(θ1))=r1(cosθ1isinθ1)\frac{1}{z_1} = z_1^{-1} = e^{-i\pi\theta_1} = r_1 \cdot (\cos (-\theta_1) + \text{i} \sin (-\theta_1)) = r_1 \cdot (\cos \theta_1 - \text{i} \sin \theta_1)

单位复根

我们用 ωn\omega_n 表示 cos(2πn)+isin(2πn)\cos (\frac{2 \pi}{n}) + \text{i} \sin (\frac{2 \pi}{n}) ,由欧拉公式知 ωn=e2πin\omega_n = e^{\frac{2 \pi i}{n}}

单位复根有三个重要的性质。对于任意正整数 nn 和整数 kk ,有:

  1. ωnn=1\omega_n^n = 1
  2. ωnk=ω2n2k\omega_n^k = \omega_{2n}^{2k} ,也就是对于 2n2 \vert n ,都有 ωnk=ωn2k2\omega_n^k = \omega_{\frac n2}^{\frac k2}
  3. ω2nk+n=ω2nk\omega_{2n}^{k + n} = -\omega_{2n}^k ,也就是对于 2n2 \vert n ,都有 ωnk+n2=ωnk\omega_n^{k + \frac n2} = - \omega_n^k

这三个性质对后续有重要作用。

多项式的表示

系数表示法

系数表示法就是用一个多项式的各个项系数来表达这个多项式,即使用一个系数序列来表示多项式:

f(x)=a0+a1x+a2x2++anxnf(x) = a_0+a_1x+a_2x^2+\cdots +a_{n}x^{n}

点值表示法

点值表示法是把这个多项式看成一个函数,从上面选取 n+1n + 1 个点,从而利用这 n+1n + 1 个点来唯一地表示这个函数(你可以用高斯消元来理解唯一性)。

f(x)=a0+a1x+a2x2++anxnf(x)={a0,a1,,an}f(x) = a_0+a_1x+a_2x^2+\cdots +a_{n}x^{n} \Leftrightarrow f(x) = \{a_0, a_1, \cdots,a_{n}\}

f(x0)=y0=a0+a1x0+a2x02+a3x03++anx0nf(x1)=y1=a0+a1x1+a2x12+a3x13++anx1nf(x2)=y2=a0+a1x2+a2x22+a3x23++anx2nf(xn)=yn=a0+a1xn+a2xn2+a3xn3++anxnn\begin{array}{c} f(x_0) = y_0 = a_0 + a_1x_0+a_2x_0^2+a_3x_0^3+ \cdots + a_nx_0^n\\ f(x_1) = y_1 = a_0 + a_1x_1+a_2x_1^2+a_3x_1^3+ \cdots + a_nx_1^n\\ f(x_2) = y_2 = a_0 + a_1x_2+a_2x_2^2+a_3x_2^3+ \cdots + a_nx_2^n\\ \vdots\\ f(x_{n}) = y_{n} = a_0 + a_1x_{n}+a_2x_{n}^2+a_3x_{n}^3+ \cdots + a_nx_{n}^n \end{array}

那么用点值表示法表示 f(x)f(x) 如下:

f(x)=a0+a1x+a2x2++anxnf(x)={(x0,y0),(x1,y1),,(xn,yn)}f(x) = a_0+a_1x+a_2x^2+\cdots +a_{n}x^{n} \Leftrightarrow f(x) = \{(x_0,y_0),(x_1,y_1), \cdots,(x_n,y_{n})\}

数论

由欧拉定理可知,对 aZa \in \mathbb{Z}mNm \in \mathbb{N^*},若 gcd(a,m)=1\gcd (a,m)=1,则 aφ(m)1(modm)a^{\varphi(m)}\equiv 1 (\operatorname{mod} m)

因此满足同余式 an1(modm)a^n \equiv 1 (\operatorname{mod} m) 的最小正整数 nn 存在,这个 nn 称作 aamm,记作 δm(a)\delta_m(a)

这里有几个性质:

性质 1a,a2,,aδm(a)a,a^2,\cdots,a^{\delta_m(a)}mm 两两不同余。

性质 2:若 an1(modm)a^n \equiv 1(\operatorname{mod} m),则 δm(a)n\delta_m(a)∣n

性质 3:若 apaq(modm)a^p \equiv a^q (\operatorname{mod} m),则有 pq(modδm(a))p≡q(\operatorname{mod} \delta_m(a))

以下两个是与四则运算有关的重要性质。

性质 4:设 mNm \in \mathbb{N^*}a,bZa,b \in \mathbb{Z}gcd(a,m)=gcd(b,m)=1\gcd(a,m)=\gcd(b,m)=1,则 δm(ab)=δm(a)δm(b)\delta_m(ab)=\delta_m(a)\delta_m(b) 的充要条件是 gcd(δm(a),δm(b))=1\gcd\big(\delta_m(a),\delta_m(b)\big)=1

性质 5:设 kNk \in \mathbb{N}mNm \in \mathbb{N^*}aZa \in \mathbb{Z}gcd(a,m)=1\gcd(a,m)=1,则 δm(ak)=δm(a)gcd(δm(a),k)\delta_m(a^k)=\dfrac{\delta_m(a)}{\gcd\big(\delta_m(a),k\big)}

原根

原根:设 mNm \in \mathbb{N^∗}aZa \in \mathbb{Z}。若 gcd(a,m)=1\gcd(a,m)=1,且 δm(a)=φ(m)\delta_m(a)=\varphi(m),则称 aa 为模 mm原根。特别的,若 mm 为质数,则它的原根 aa 满足 aimodma^i \bmod mi[1,m1]i \in [1, m - 1]两两不同

原根判定定理:若一个数 gg 是模 mm 的原根,则有对于 φ(m)\varphi(m) 任何大于 11 且不为自身的因数 pp,都有 gφ(m)/p(modm)g^{φ(m)/p} \not\equiv (\operatorname{mod} m)

原根个数:若一个数 mm 有原根,则它原根的个数为 φ(φ(m))\varphi(\varphi(m))

原根存在定理:一个数 mm 存在原根当且仅当 m=2,4,pα,2pαm=2,4,p^{\alpha},2p^{\alpha},其中 pp 为奇素数,αN\alpha\in \mathbb{N^*}

最小原根的数量级:王元于 1959 年证明了若 mm 有原根,其最小原根是不多于 m0.25m^{0.25} 级别的。这保证了我们暴力找一个数的最小原根,复杂度是可以接受的。

卷积

对于一个序列,将其中元素一一映射到一个多项式函数的系数上, 这个多项式函数便叫做该序列的生成函数

形式化地讲,对于序列 f0,1,n1f_{0, 1, \cdots n - 1} ,它的生成函数就是 i=0n1fixi\sum_{i = 0}^{n - 1} f_i \cdot x^i

卷积即为生成函数的乘积在对应序列的变换上的的抽象,“卷”即为其作用效果,“积”即为其本质。

对于序列 f,gf, g ,其卷积序列 fgf \otimes g 满足 (fg)k=i=0kfigki=i+j=kfigj(f \otimes g)_k = \sum_{i = 0}^k f_i \cdot g_{k - i} = \sum_{i + j = k} f_i \cdot g_j ,其中 (fg)k(f \otimes g)_k 表示卷积序列 fgf \otimes g 的第 kk 项。

对于多项式 f,gf, g,其多项式的卷积为:fg=k=0degf+degg((i+j=kaibj)xk)f \otimes g = \sum_{k = 0}^{\operatorname{deg} f + \operatorname{deg} g} ((\sum_{i + j = k} a_i \cdot b_j)\cdot x_k)

DFT

理论知识

DFT(傅里叶变换)是利用 O(nlogn)O(n \log n) 的时间把多项式从系数表示转到了点值表示。

请注意,DFT 只能解决 n=2mn = 2^m 时的问题,否则在分治的时候左右不一样长。所以要在第一次 DFT 之前就把序列向上补成长度为 2m2^m(高次系数补 00)、最高项次数为 2m12^m - 1 的多项式。在讲述时,默认 n=2mn = 2^m ,其中 mNm \in \mathbb{N}

DFT 它分治地来求当 x=ωnkx = \omega_n^k 的时候 f(x)f(x) 的值。也就是

f(x)=a0+a1x+a2x2++anxnf(x)={(ωn0,f(ωn0)),(ωn1,f(ωn1)),,(ωnn,f(ωnn))}f(x) = a_0+a_1x+a_2x^2+\cdots +a_{n}x^{n} \Rightarrow f(x) = \{(\omega_n^0,f(\omega_n^0)),(\omega_n^1,f(\omega_n^1)), \cdots,(\omega_n^n,f(\omega_n^n))\}

相当于将数组 ff 的意义从多项式系数(f[i]=[xi]f(x)f[i] = [x^i]f(x))表示转到了点值表示(f[i]=f(ωni)f[i] = f(\omega_n^i))。

它的分治思想体现在将多项式分为奇次项和偶次项处理。举个例子,当 n=8n = 8 时,对于一共 88 项的多项式:

f(x)=a0+a1x+a2x2+a3x3+a4x4+a5x5+a6x6+a7x7f(x) = a_0 + a_1x + a_2x^2+a_3x^3+a_4x^4+a_5x^5+a_6x^6+a_7x^7

按照次数的奇偶来分成两组,然后右边提出来一个 xx

f(x)=(a0+a2x2+a4x4+a6x6)+(a1x+a3x3+a5x5+a7x7)=(a0+a2x2+a4x4+a6x6)+x(a1+a3x2+a5x4+a7x6)\begin{aligned} f(x) &= (a_0+a_2x^2+a_4x^4+a_6x^6) + (a_1x+a_3x^3+a_5x^5+a_7x^7)\\ &= (a_0+a_2x^2+a_4x^4+a_6x^6) + x(a_1+a_3x^2+a_5x^4+a_7x^6) \end{aligned}

分别用奇偶次次项数建立新的函数:

g(x)=a0+a2x+a4x2+a6x3h(x)=a1+a3x+a5x2+a7x3\begin{aligned} g(x) &= a_0+a_2x+a_4x^2+a_6x^3 \\ h(x) &= a_1+a_3x+a_5x^2+a_7x^3 \end{aligned}

那么原来的 f(x)f(x) 用新函数表示为:

f(x)=g(x2)+xh(x2)f(x)=g\left(x^2\right) + x \cdot h\left(x^2\right)

x=ωnkx = \omega_n^k 代入:

f(ωnk)=g(ωn2k)+ωnkh(ωn2k)=g(ωn2k)+ωnkh(ωn2k)\begin{aligned} f(\omega_n^k) &= g(\omega_n^{2k}) + \omega_n^k \cdot h(\omega_n^{2k}) \\ &= g(\omega_{\frac{n}{2}}^k) + \omega_n^k \cdot h(\omega_{\frac{n}{2}}^k) \end{aligned}

DFT 最神奇的地方就在于下面这个式子(将 x=ωnk+n2x = \omega_n^{k + \frac{n}{2}}​ 带入):

f(ωnk+n2)=g(ωn2k+n)+ωnk+n2h(ωn2k+n)=g(ωn2k)ωnkh(ωn2k)=g(ωn2k)ωnkh(ωn2k)\begin{aligned} f(\omega_n^{k + \frac{n}{2}}) &= g(\omega_n^{2k + n}) + \omega_n^{k + \frac {n}{2}} \cdot h(\omega_n^{2k + n}) \\ &= g(\omega_n^{2k}) - \omega_n^{k} \cdot h(\omega_n^{2k}) \\ &= g(\omega_{\frac{n}{2}}^k) - \omega_n^{k} \cdot h(\omega_{\frac{n}{2}}^k) \end{aligned}

也就是说我们只要知道 g(ωn2k)g(\omega_{\frac{n}{2}}^k)​ 和 h(ωn2k)h(\omega_{\frac{n}{2}}^k)​ 就可以算出 f(ωnk)f(\omega_n^k)​ 和 f(ωnk+n2)f(\omega_n^{k + \frac{n}{2}})​ 了!此时我们只要递归求出 g(ωn2k)g(\omega_{\frac{n}{2}}^k)​ 和 h(ωn2k)h(\omega_{\frac{n}{2}}^k)​ 即可。

假设算出所有 f(ωnk)f(\omega_n^k) 的时间复杂度为 T(n)T(n) ,则有 T(n)=2T(n/2)+n=O(nlogn)T(n) = 2 \cdot T(n / 2) + n = O(n \log n)

代码实现方面,STL 提供了复数的模板,当然也可以手动实现。两者区别在于,使用 STL 的 complex 可以调用 exp 函数求出 ωn\omega_n。但事实上使用欧拉公式得到的虚数来求 ωn\omega_n 也是等价的。

代码实现

OI WIKI 上的 DFT(STL 版):

#include <cmath>
#include <complex>

typedef std::complex<double> Comp;  // STL complex

const Comp I(0, 1);  // i
const int MAX_N = 1 << 20;

Comp tmp[MAX_N];

void DFT(Comp *f, int n) {
    if (n == 1) return;
    for (int i = 0; i < n; ++i) tmp[i] = f[i];
    for (int i = 0; i < n; ++i) {  // 偶数放左边,奇数放右边
        if (i & 1)
            f[n / 2 + i / 2] = tmp[i];
        else
            f[i / 2] = tmp[i];
    }
    Comp *g = f, *h = f + n / 2;
    DFT(g, n / 2), DFT(h, n / 2);  // 递归 DFT
    Comp cur(1, 0), step(cos(2 * M_PI / n), sin(2 * M_PI / n));
    // Comp step = exp(I * (2 * M_PI / n)); // 两个 step 定义是等价的
    for (int k = 0; k < n / 2; ++k) { 
        tmp[k] = g[k] + cur * h[k];
        tmp[k + n / 2] = g[k] - cur * h[k];
        cur *= step;
    }
    for (int i = 0; i < n; ++i) f[i] = tmp[i];
}

这里 g[k]g[k]​ 从多项式系数变为 g(ωn2k)g(\omega_{\frac{n}{2}}^k)​ ,h[k]h[k]​ 从多项式系数变为 h(ωn2k)h(\omega_{\frac{n}{2}}^k)​ ,step=ωnstep = \omega_n​ ,cur=ωnkcur = \omega_n^k​​ 。

IDFT

IDFT(傅里叶逆变换)是利用 O(nlogn)O(n \log n) 的时间把多项式从点值表示转到了系数表示。也就是:

f(x)={(ωn0,f(ωn0)),(ωn1,f(ωn1)),,(ωnn,f(ωnn))}f(x)=a0+a1x+a2x2++anxnf(x) = \{(\omega_n^0,f(\omega_n^0)),(\omega_n^1,f(\omega_n^1)), \cdots,(\omega_n^n,f(\omega_n^n))\} \Rightarrow f(x) = a_0+a_1x+a_2x^2+\cdots +a_{n}x^{n}

相当于将数组 ff 的意义从点值点值(f[i]=f(ωni)f[i] = f(\omega_n^i))表示转到了多项式系数(f[i]=[xi]f(x)f[i] = [x^i]f(x))。

对此我们有两种理解方式。

请注意,这里的 (ωnk,f(ωnk))(\omega_n^k, f(\omega_n^k))​ 与 DFT 中的 (ωnk,f(ωnk))(\omega_n^k, f(\omega_n^k))不同。这里的 (ωnk,f(ωnk))(\omega_n^k, f(\omega_n^k))可以不和 DFT 中的 (ωnk,f(ωnk))(\omega_n^k, f(\omega_n^k))相同

(也就是说,这里的 f(ωnk)f(\omega_n^k) 可能是 g(ωnk)h(ωnk)g(\omega_n^k) \cdot h(\omega_n^k) 等数的值)

理解方式一:线性代数角度

IDFT(傅里叶反变换)的作用,是把目标多项式的点值形式转换成系数形式。而 DFT 本身是个线性变换,可以理解为将目标多项式当作向量,左乘一个矩阵得到变换后的向量,以模拟把单位复根代入多项式的过程:

[y0y1y2y3yn1]=[111111ωn1ωn2ωn3ωnn11ωn2ωn4ωn6ωn2(n1)1ωn3ωn6ωn9ωn3(n1)1ωnn1ωn2(n1)ωn3(n1)ωn(n1)2][a0a1a2a3an1]\begin{bmatrix}y_0 \\ y_1 \\ y_2 \\ y_3 \\ \vdots \\ y_{n-1} \end{bmatrix} = \begin{bmatrix}1 & 1 & 1 & 1 & \cdots & 1 \\ 1 & \omega_n^1 & \omega_n^2 & \omega_n^3 & \cdots & \omega_n^{n-1} \\ 1 & \omega_n^2 & \omega_n^4 & \omega_n^6 & \cdots & \omega_n^{2(n-1)} \\ 1 & \omega_n^3 & \omega_n^6 & \omega_n^9 & \cdots & \omega_n^{3(n-1)} \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \omega_n^{n-1} & \omega_n^{2(n-1)} & \omega_n^{3(n-1)} & \cdots & \omega_n^{(n-1)^2} \end{bmatrix} \begin{bmatrix} a_0 \\ a_1 \\ a_2 \\ a_3 \\ \vdots \\ a_{n-1} \end{bmatrix}

我们现在已经知道等号左边 [y0y1y2y3yn1]\begin{bmatrix}y_0 \\ y_1 \\ y_2 \\ y_3 \\ \vdots \\ y_{n-1} \end{bmatrix} 的结果了,我们的目标是求出 [a0a1a2a3an1]\begin{bmatrix} a_0 \\ a_1 \\ a_2 \\ a_3 \\ \vdots \\ a_{n-1} \end{bmatrix} 的结果。

根据矩阵的基础知识,我们只要求出中间的大矩阵的逆矩阵即可。

引理:对于一个 n×nn \times n 的矩阵:

[a1,1a1,2a1,3a1,4a1,na2,1a2,2a2,3a2,4a2,na3,1a3,2a3,3a3,4a3,na4,1a4,2a4,3a4,4a4,nan,1an,2an,3an,4an,n]\begin{bmatrix}a_{1,1} & a_{1,2} & a_{1,3} & a_{1,4} & \cdots & a_{1,n} \\ a_{2,1} & a_{2,2} & a_{2,3} & a_{2,4} & \cdots & a_{2,n} \\ a_{3,1} & a_{3,2} & a_{3,3} & a_{3,4} & \cdots & a_{3, n} \\ a_{4, 1} & a_{4, 2} & a_{4, 3} & a_{4, 4} & \cdots & a_{4, n} \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots \\ a_{n, 1} & a_{n, 2} & a_{n, 3} & a_{n, 4} & \cdots & a_{n, n} \end{bmatrix}

满足对于 i,j[1,n]\forall i, j \in [1, n] ,都有 ai,j=aj,ia_{i, j} = a_{j, i} ,则有

[a1,1a1,2a1,3a1,4a1,na2,1a2,2a2,3a2,4a2,na3,1a3,2a3,3a3,4a3,na4,1a4,2a4,3a4,4a4,nan,1an,2an,3an,4an,n]1=1n[a1,11a1,21a1,31a1,41a1,n1a2,11a2,21a2,31a2,41a2,n1a3,11a3,21a3,31a3,41a3,n1a4,11a4,21a4,31a4,41a4,n1an,11an,21an,31an,41an,n1]\begin{bmatrix}a_{1,1} & a_{1,2} & a_{1,3} & a_{1,4} & \cdots & a_{1,n} \\ a_{2,1} & a_{2,2} & a_{2,3} & a_{2,4} & \cdots & a_{2,n} \\ a_{3,1} & a_{3,2} & a_{3,3} & a_{3,4} & \cdots & a_{3, n} \\ a_{4, 1} & a_{4, 2} & a_{4, 3} & a_{4, 4} & \cdots & a_{4, n} \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots \\ a_{n, 1} & a_{n, 2} & a_{n, 3} & a_{n, 4} & \cdots & a_{n, n} \end{bmatrix}^{-1} = \frac {1}{n} \begin{bmatrix}a_{1,1}^{-1} & a_{1,2}^{-1} & a_{1,3}^{-1} & a_{1,4}^{-1} & \cdots & a_{1,n}^{-1} \\ a_{2,1}^{-1} & a_{2,2}^{-1} & a_{2,3}^{-1} & a_{2,4}^{-1} & \cdots & a_{2,n}^{-1} \\ a_{3,1}^{-1} & a_{3,2}^{-1} & a_{3,3}^{-1} & a_{3,4}^{-1} & \cdots & a_{3, n}^{-1} \\ a_{4, 1}^{-1} & a_{4, 2}^{-1} & a_{4, 3}^{-1} & a_{4, 4}^{-1} & \cdots & a_{4, n}^{-1} \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots \\ a_{n, 1}^{-1} & a_{n, 2}^{-1} & a_{n, 3}^{-1} & a_{n, 4}^{-1} & \cdots & a_{n, n}^{-1} \end{bmatrix}

而中间这个矩阵 [111111ωn1ωn2ωn3ωnn11ωn2ωn4ωn6ωn2(n1)1ωn3ωn6ωn9ωn3(n1)1ωnn1ωn2(n1)ωn3(n1)ωn(n1)2]\begin{bmatrix}1 & 1 & 1 & 1 & \cdots & 1 \\ 1 & \omega_n^1 & \omega_n^2 & \omega_n^3 & \cdots & \omega_n^{n-1} \\ 1 & \omega_n^2 & \omega_n^4 & \omega_n^6 & \cdots & \omega_n^{2(n-1)} \\ 1 & \omega_n^3 & \omega_n^6 & \omega_n^9 & \cdots & \omega_n^{3(n-1)} \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \omega_n^{n-1} & \omega_n^{2(n-1)} & \omega_n^{3(n-1)} & \cdots & \omega_n^{(n-1)^2} \end{bmatrix} 可以用这个引理求得它的逆矩阵。

ωnk=(ωn1)k=cos(k2πn)isin(k2πn)\omega_n^{-k} = (\omega_n^{-1})^k = \cos (k \cdot \frac{2\pi}{n}) - \text{i} \cdot \sin (k \cdot \frac{2\pi}{n})

而其它的操作过程与 DFT 是完全相同的。我们可以定义一个函数,在里面加一个参数 11 或者是 1-1 ,然后把它乘到 sin(k2πn)\sin (k \cdot \frac{2\pi}{n}) 上。传入 11 就是 DFT,传入 1-1 就是 IDFT。

理解方式二:单位复根周期性

利用单位复根的周期性同样可以理解 IDFT 与 DFT 之间的关系。

考虑原本的多项式是 f(x)=a0+a1x+a2x2++an1xn1=i=0n1aixif(x) = a_0 + a_1 x + a_2 x^2 + \cdots + a_{n-1} x^{n-1} = \sum_{i = 0}^{n - 1}a_i x^i。而 IDFT 就是把点值表示还原为系数表示。

考虑 构造法。我们已知 yi=f(ωni),i{0,1,,n1}y_i=f\left( \omega_n^i \right),i\in\{0,1,\cdots,n-1\},求 {a0,a1,,an1}\{a_0,a_1,\cdots,a_{n-1}\}。构造多项式如下:

A(x)=i=0n1yixiA(x)=\sum_{i=0}^{n-1}y_ix^i

相当于把 {y0,y1,y2,,yn1}\{y_0,y_1,y_2,\cdots,y_{n-1}\} 当做多项式 AA 的系数表示法。

这时我们有两种推导方式,这对应了两种实现方法。

推导方法一

bi=ωnib_i=\omega_n^{-i} ,则多项式 AAx=b0,b1,,bn1x=b_0,b_1,\cdots,b_{n-1} 处的点值表示法为 {A(b0),A(b1),,A(bn1)}\left\{ A(b_0),A(b_1),\cdots,A(b_{n-1}) \right\}

A(x)A(x) 做一下变换,可以将 A(bk)A(b_k) 表示为:

A(bk)=i=0n1f(ωni)ωnik=i=0n1ωnikj=0n1aj(ωni)j=i=0n1j=0n1ajωni(jk)=j=0n1aji=0n1(ωnjk)i\begin{aligned} A(b_k)&=\sum_{i=0}^{n-1}f(\omega_n^i)\omega_n^{-ik}=\sum_{i=0}^{n-1}\omega_n^{-ik}\sum_{j=0}^{n-1}a_j(\omega_n^i)^{j}\\ &=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j\omega_n^{i(j-k)}=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}\left(\omega_n^{j-k}\right)^i \end{aligned}

下面记 S(n,a)=i=0n1(ωna)iS(n, a)=\sum_{i=0}^{n-1}\left(\omega_n^a\right)^i

可以发现,在 a0(modn)a \equiv 0 (\operatorname{mod} n) 时,S(n,a)=nS(n, a) = n

而在 a0(modn)a \not \equiv 0 (\operatorname{mod} n) 时,有:

S(ωna)=(ωna)n(ωna)0ωna1=0\begin{aligned} S\left(\omega_n^a\right)=\frac{\left(\omega_n^a\right)^n-\left(\omega_n^a\right)^0}{\omega_n^a-1}=0\end{aligned}

也就是说,S(n,a)S(n, a) 满足

S(n,a)={n,a0(modn)0,a0(modn)S\left(n, a\right)= \left\{\begin{aligned} n &, a \equiv 0 (\operatorname{mod} n) \\ 0 &, a \not \equiv 0 (\operatorname{mod} n) \end{aligned}\right.

将其带回原式:

A(bk)=j=0n1aji=0n1(ωnjk)i=j=0n1ajS(n,jk)=aknA(b_k)=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}\left(\omega_n^{j-k}\right)^i=\sum_{j=0}^{n-1}a_jS\left(n, j-k\right)=a_k\cdot n

也就是说给定点 bi=ωnib_i=\omega_n^{-i},则 AA 的点值表示法为:

{(b0,A(b0)),(b1,A(b1)),,(bn1,A(bn1))}={(b0,a0n),(b1,a1n),,(bn1,an1n)}\begin{aligned} &\left\{ (b_0,A(b_0)),(b_1,A(b_1)),\cdots,(b_{n-1},A(b_{n-1})) \right\}\\ =&\left\{ (b_0,a_0\cdot n),(b_1,a_1\cdot n),\cdots,(b_{n-1},a_{n-1}\cdot n) \right\} \end{aligned}

说白了,就是说 ak=A(ωnk)n\begin{aligned} a_k = \frac{A(\omega_n^{-k})}{n} \end{aligned}

所以说,我们取单位根为 ωni\omega_n^{-i} ,对着 A(x)=i=0n1yixiA(x)=\sum_{i=0}^{n-1}y_ix^i 跑一遍 DFT 后将求得的数都除以 nn 即可。

推导方法二

我们直接将 ωnk\omega_n^k 代入 A(x)A(x)

A(ωnk)=i=0n1f(ωni)ωnik=i=0n1ωnikj=0n1aj(ωni)j=i=0n1j=0n1ajωni(j+k)=j=0n1aji=0n1(ωnj+k)i=A(ωnk)=j=0n1ajS(n,j+k)\begin{aligned} A(\omega^k_n)&=\sum_{i=0}^{n-1}f(\omega_n^i)\omega_n^{ik}=\sum_{i=0}^{n-1}\omega_n^{ik}\sum_{j=0}^{n-1}a_j(\omega_n^i)^{j}\\ &=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j\omega_n^{i(j+k)}=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}\left(\omega_n^{j+k}\right)^i = A(\omega_n^k) = \sum_{j=0}^{n-1}a_jS(n,j+k) \end{aligned}

这里 S(n,a)=i=0n1(ωna)iS(n, a)=\sum_{i=0}^{n-1}\left(\omega_n^a\right)^i

于是,当且仅当 j+k0(modn)j + k \equiv 0 (\operatorname{mod} n) 时有 S(n,j+k)=nS(n, j + k) = n ,其余情况 S(n,j+k)=0S(n, j + k) = 0 ,因此 A(ωnk)=anknA(\omega_n^k) = a_{n-k}\cdot n,其中 k1k \ge 1。而 A(ωn0)=a0nA(\omega_n^0) = a_0 \cdot n

这意味着我们将 {y0,y1,y2,,yn1}\{y_0,y_1,y_2,\cdots,y_{n-1}\} 做 DFT 变换后,a1ana_1 \sim a_n 反转再除以 nn,同样可以还原 f(x)f(x) 的系数表示。

代码实现

现在我们可以集 DFT 和 IDFT 于一身。代码实现如下(这里用的是推导方法一):

#include <cmath>
#include <complex>

typedef std::complex<double> Comp;  // STL complex

const Comp I(0, 1);  // i
const int MAX_N = 1 << 20;

Comp tmp[MAX_N];

void DFT(Comp *f, int n, int rev) {  // rev = 1, DFT; rev = -1, IDFT
    if (n == 1) return;
    for (int i = 0; i < n; ++i) tmp[i] = f[i];
    for (int i = 0; i < n; ++i) {  // 偶数放左边,奇数放右边
        if (i & 1)
            f[n / 2 + i / 2] = tmp[i];
        else
            f[i / 2] = tmp[i];
    }
    Comp *g = f, *h = f + n / 2;
    DFT(g, n / 2, rev), DFT(h, n / 2, rev);  // 递归 DFT
    Comp cur(1, 0), step(cos(2 * M_PI / n), rev * sin(2 * M_PI / n));
    for (int k = 0; k < n / 2; ++k) {
        tmp[k] = g[k] + cur * h[k];
        tmp[k + n / 2] = g[k] - cur * h[k];
        cur *= step;
    }
    for (int i = 0; i < n; ++i) f[i] = tmp[i];
}

这里 g[k]=g(ωn2k)g[k] = g(\omega_{\frac{n}{2}}^k)h[k]=h(ωn2k)h[k] = h(\omega_{\frac{n}{2}}^k)step=ωnstep = \omega_ncur=ωnkcur = \omega_n^k。别忘了最后还要除以 nn

FFT

位逆序置换

这个算法还可以从“分治”的角度继续优化。我们每一次都会把整个多项式的奇数次项和偶数次项系数分开,一直分到只剩下一个系数。但是,这个递归的过程需要更多的内存。因此,我们可以先“模仿递归”把这些系数在原数组中“拆分”,然后再“倍增”地去合并这些算出来的值。

88 项多项式为例,模拟拆分的过程:

  • 初始序列为 {x0,x1,x2,x3,x4,x5,x6,x7}\{x_0, x_1, x_2, x_3, x_4, x_5, x_6, x_7\}
  • 一次拆分之后 {x0,x2,x4,x6},{x1,x3,x5,x7}\{x_0, x_2, x_4, x_6\},\{x_1, x_3, x_5, x_7 \}
  • 两次拆分之后 {x0,x4}{x2,x6},{x1,x5},{x3,x7}\{x_0,x_4\} \{x_2, x_6\},\{x_1, x_5\},\{x_3, x_7 \}
  • 三次拆分之后 {x0}{x4}{x2}{x6}{x1}{x5}{x3}{x7}\{x_0\}\{x_4\}\{x_2\}\{x_6\}\{x_1\}\{x_5\}\{x_3\}\{x_7 \}

规律:其实就是原来的那个序列,每个下标用二进制表示,然后把二进制翻转对称一下,就是最终那个位置的下标。比如 x1x_1 是 001,翻转是 100,也就是 4,而且最后那个位置的下标确实是 4。我们称这个变换为位逆序置换(bit-reversal permutation,国内也称蝴蝶变换)

根据它的定义,我们可以在 O(nlogn)O(n \log n) 的时间内求出每个数变换后的结果,以下是代码:(OI WIKI 上的)

void change(Complex y[], int len) {
    int i, j, k;
    for (int i = 1, j = len / 2; i < len - 1; i++) {
        if (i < j) swap(y[i], y[j]);
        // 交换互为小标反转的元素,i<j 保证交换一次
        // i 做正常的 + 1,j 做反转类型的 + 1,始终保持 i 和 j 是反转的
        k = len / 2;
        while (j >= k) {
            j = j - k;
            k = k / 2;
        }
        if (j < k) j += k;
    }
}

我们实际上可以利用 O(n)O(n) 递推来实现该变换,以下是代码:(OI WIKI 上的)

// 同样需要保证 len 是 2 的幂
// 记 rev[i] 为 i 翻转后的值
void change(Complex y[], int len) {
    for (int i = 0; i < len; ++i) {
        rev[i] = rev[i >> 1] >> 1;
        if (i & 1) {  // 如果最后一位是 1,则翻转成 len/2
            rev[i] |= len >> 1;
        }
        // 2x6_81 注: 上面 4 行等价于 rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (len - 1));
    }
    for (int i = 0; i < len; ++i) {
        if (i < rev[i]) {  // 保证每对数只翻转一次
          swap(y[i], y[rev[i]]);
        }
    }
    return;
}

当然,这个位逆序置换不是没有道理的,可以这么理解:

还是以 88 项多项式为例,模拟拆分的过程(下标用二进制表示):

  • 初始序列为 {x000,x001,x010,x011,x100,x101,x110,x111}\{x_{000}, x_{001}, x_{010}, x_{011}, x_{100}, x_{101}, x_{110}, x_{111}\}
  • 一次拆分之后 {x000,x010,x100,x110},{x001,x011,x101,x111}\{x_{000}, x_{010}, x_{100}, x_{110}\},\{x_{001}, x_{011}, x_{101}, x_{111} \}
  • 两次拆分之后 {x000,x100}{x010,x110},{x001,x101},{x011,x111}\{x_{000},x_{100}\} \{x_{010}, x_{110}\},\{x_{001}, x_{101}\},\{x_{011}, x_{111}\}
  • 三次拆分之后 {x000},{x100}{x010},{x110},{x001},{x101},{x011},{x111}\{x_{000}\},\{x_{100}\} \{x_{010}\},\{x_{110}\},\{x_{001}\},\{x_{101}\},\{x_{011}\},\{x_{111}\}

可以发现,每次拆分就是一个按位的基数排序

第一次拆分之后,第 00 位为 00 的都在左边,为 11 的都在右边。

第二次拆分之后,每一组的第 11 位为 00 的都在左边,为 11 的都在右边。

以此类推,于是就是位逆序置换了。

代码实现(IDFT 推导方法一)

摘自 OI WIKI 的代码:

/*
 * 做 FFT
 * len 必须是 2^k 形式
 * on == 1 时是 DFT,on == -1 时是 IDFT
 */
void fft(Complex y[], int len, int on) {
	change(y, len);
    for (int h = 2; h <= len; h <<= 1) {                  // 模拟合并过程
        Complex wn(cos(2 * PI / h), sin(on * 2 * PI / h));  // 计算当前单位复根 (要根据标记取判断是否取倒数)
        for (int j = 0; j < len; j += h) {
            Complex w(1, 0);  // 计算当前单位复根
            for (int k = j; k < j + h / 2; k++) {
                Complex u = y[k];
                Complex t = w * y[k + h / 2];
                y[k] = u + t;  // 这就是把两部分分治的结果加起来
                y[k + h / 2] = u - t;
                // 后半个 “step” 中的ω一定和 “前半个” 中的成相反数
                // “红圈”上的点转一整圈“转回来”,转半圈正好转成相反数
                // 一个数相反数的平方与这个数自身的平方相等
                w = w * wn;
            }
        }
    }
    if (on == -1) {
        for (int i = 0; i < len; i++) {
            y[i].x /= len;
        }
    }
}

代码实现(IDFT 推导方法二)

摘自 OI WIKI 的代码:

/*
 * 做 FFT
 * len 必须是 2^k 形式
 * on == 1 时是 DFT,on == -1 时是 IDFT
 */
void fft(Complex y[], int len, int on) {
    change(y, len);
    for (int h = 2; h <= len; h <<= 1) {             // 模拟合并过程
        Complex wn(cos(2 * PI / h), sin(2 * PI / h));  // 计算当前单位复根 (不用取倒数)
        for (int j = 0; j < len; j += h) {
            Complex w(1, 0);  // 计算当前单位复根
            for (int k = j; k < j + h / 2; k++) {
                Complex u = y[k];
                Complex t = w * y[k + h / 2];
                y[k] = u + t;  // 这就是把两部分分治的结果加起来
                y[k + h / 2] = u - t;  // 这就是把两部分分治的结果相减
                w = w * wn;
            }
        }
    }
    if (on == -1) {
        reverse(y + 1, y + len); // 反转操作
        for (int i = 0; i < len; i++) {
            y[i].x /= len;
        }
    }
}

请注意,如果你的所求序列的数过大,你需要用 long double 才能放置被卡精度(例如 102310^{23} 级别)。

NTT

NTT(快速数论变换)解决的是多项式乘法带模数的情况,可以说有些受模数的限制,数也比较大。但是比较方便,毕竟没有复数部分。

由于 FFT 中有浮点数乘法,常数较大。而 NTT 中均为整数,常数相对较小。

理论知识

对于一个质数 pp,若它满足 p=qn+1p = q \cdot n + 1 ,其中 nn 满足 n=2m, mNn = 2^m,\ m \in \mathbb{N},那么它的原根 gqn1(modp)g^{qn} \equiv 1 (\operatorname{mod} p)

我们可以用 gig_i 表示 gp1ig^{\frac{p - 1}{i}} 。将 gig_i 看作 ωi\omega_i 的等价,可以发现 gig_i 满足与 ωi\omega_i 相似的性质,例如对于任意正整数 n=2mn = 2^m 和整数 kk ,有:

  1. ωnn=1gnn1(modp)\omega_n^n = 1 \Leftrightarrow g_n^n \equiv 1 (\operatorname{mod} p)

  2. ωnk=ω2n2kgnkg2n2k(modp)\omega_n^k = \omega_{2n}^{2k} \Leftrightarrow g_n^k \equiv g_{2n}^{2k} (\operatorname{mod} p)

  3. ω2nk+n=ω2nkg2nk+ng2nk(modp)\omega_{2n}^{k + n} = -\omega_{2n}^k \Leftrightarrow g_{2n}^{k + n} \equiv -g_{2n}^{k} (\operatorname{mod} p)

  4. gnn21(modp)g_n^{\frac{n}{2}} \equiv -1 (\operatorname{mod} p)

这里常见的模数 pp 有:

p=998244353=7×17×223+1,g=3p=998244353=7 \times 17 \times 2^{23}+1, g=3

p=1004535809=479×221+1,g=3p = 1004535809 = 479 \times 2^{21}+1, g=3

下面这个也能用:

p=469762049=7×226+1,g=3p = 469762049 = 7 \times 2^{26} + 1, g = 3

这里 gg 表示 pp 的一个原根。

NTT 部分与 FFT 一样,这里推导一下 INTT。

INTT 推导

可以从线性代数角度出发,通过根据标记取判断是否取模 pp 的逆元(理解方式一)。

也可以和 IDFT 推导方法二一样,考虑 构造法。我们已知 yi=f(gni),i{0,1,,n1}y_i=f\left(g_n^i \right),i\in\{0,1,\cdots,n-1\}​,求 {a0,a1,,an1}\{a_0,a_1,\cdots,a_{n-1}\}​。构造多项式如下:

A(x)=i=0n1yixiA(x)=\sum_{i=0}^{n-1}y_ix^i

相当于把 {y0,y1,y2,,yn1}\{y_0,y_1,y_2,\cdots,y_{n-1}\} 当做多项式 AA 的系数表示法。

我们直接将 gnkg_n^k 代入 A(x)A(x)

A(gnk)=i=0n1f(gni)gnik=i=0n1gnikj=0n1aj(gni)j=i=0n1j=0n1ajgni(j+k)=j=0n1aji=0n1(gnj+k)i=A(gnk)=j=0n1ajS(n,j+k)\begin{aligned} A(g^k_n)&=\sum_{i=0}^{n-1}f(g_n^i)g_n^{ik}=\sum_{i=0}^{n-1}g_n^{ik}\sum_{j=0}^{n-1}a_j(g_n^i)^{j}\\ &=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_jg_n^{i(j+k)}=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}\left(g_n^{j+k}\right)^i = A(g_n^k) = \sum_{j=0}^{n-1}a_jS(n,j+k) \end{aligned}

这里 S(n,a)=i=0n1(gna)iS(n, a)=\sum_{i=0}^{n-1}\left(g_n^a\right)^i

a0(modn)a \equiv 0 (\operatorname{mod} n) 时,S(n,a)=nn(modp)S(n, a) = n \equiv n (\operatorname{mod} p)

a0(modn)a \not \equiv 0 (\operatorname{mod} n) 时,S(n,a)=(gna)n(gna)0gna1(gnn)a1gna10(modp)\begin{aligned} S(n, a) = \frac{(g_n^a)^n - (g_n^a)^0}{g_n^a - 1} \equiv \frac{(g_n^n)^a - 1}{g_n^a - 1} \equiv 0 (\operatorname{mod} p) \end{aligned}

于是,当且仅当 j+k0(modn)j + k \equiv 0 (\operatorname{mod} n) 时有 S(n,j+k)n(modp)S(n, j + k) \equiv n (\operatorname{mod} p) ,其余情况 S(n,j+k)0(modp)S(n, j + k) \equiv 0 (\operatorname{mod} p) ,因此 A(gnk)ankn(modp)A(g_n^k) \equiv a_{n-k}\cdot n (\operatorname{mod} p),其中 k1k \ge 1。而 A(gn0)a0n(modp)A(g_n^0) \equiv a_0 \cdot n (\operatorname{mod} p)

这意味着我们将 {y0,y1,y2,,yn1}\{y_0,y_1,y_2,\cdots,y_{n-1}\} 做 NTT 变换后,反转 a1ana_1 \sim a_n 后再乘以 nn 的逆元,同样可以还原 f(x)f(x) 的系数在模 pp 后的表示。

代码实现

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1 << 22 | 1;
const LL mod = 998244353;
int Rev[N];
inline LL qpow(LL n, LL p, LL mod = 998244353ll) {
    LL ret = 1ll;
    while (p) {
        if (p & 1) (ret *= n) %= mod;
        (n *= n) %= mod; p >>= 1;
    }
    return ret;
}
void NTT(LL *F, int n, int Sign) { // Sign = 1: NTT , Sign = -1: INTT
	for (int i = 0; i < n; i++) {
		if (i < Rev[i]) swap(F[i], F[Rev[i]]);
	}
	for (int len = 2; len <= n; len <<= 1) {
		int mid = len >> 1;
        LL wl = qpow(3, (mod - 1) / len);
		for (int i = 0; i < n; i += len) {
			LL Ang = 1;
			for (int j = 0; j < mid; j++) {
				LL tx = F[i + j], ty = (F[i + j + mid] * Ang) % mod;
				F[i + j] = (tx + ty) % mod; F[i + j + mid] = (tx - ty + mod) % mod;
				Ang = (Ang * wl) % mod;
			}
		}
	}
    if (Sign == -1) {
        for (int i = 1; i < (n >> 1); i++) swap(F[i], F[n - i]);
        LL Invn = qpow(n, mod - 2);
        for (int i = 0; i < n; i++) {
            (F[i] *= Invn) %= mod;
        }
    }
}

例题

P3803【模板】多项式乘法(FFT)

题目链接

给定一个 nn 次多项式 F(x)F(x) ,和一个 mm 次多项式 G(x)G(x) ,请求出 F(x)F(x)G(x)G(x) 的卷积。

1n,m1061 \le n, m \le 10^6 ,保证输入中的系数大于等于 00 且小于等于 99

样例输入:1 2\n1 2\n1 2 1 ,输出:1 4 5 2

题目解答

这是一道模板题。

思路:由于 nn​ 次多项式和 mm​ 次多项式的卷积是 n+mn + m​ 次的(即有 n+m+1n + m + 1​ 项),于是首先令 pp​ 为最小的 22​ 的幂且不小于 m+n+1m + n + 1​ 。(高次系数补 00​)

首先利用两遍 FFT:

F(x)=a0+a1x+a2x2++apxpF(x)={(ωp0,f(ωp0)),(ωp1,f(ωp1)),,(ωpp,f(ωpp))}F(x) = a_0+a_1x+a_2x^2+\cdots +a_{p}x^{p} \Rightarrow F(x) = \{(\omega_p^0,f(\omega_p^0)),(\omega_p^1,f(\omega_p^1)), \cdots,(\omega_p^p,f(\omega_p^p))\}

G(x)=b0+b1x+b2x2++bpxpG(x)={(ωp0,g(ωp0)),(ωp1,g(ωp1)),,(ωpp,g(ωpp))}G(x) = b_0+b_1x+b_2x^2+\cdots +b_{p}x^{p} \Rightarrow G(x) = \{(\omega_p^0,g(\omega_p^0)),(\omega_p^1,g(\omega_p^1)), \cdots,(\omega_p^p,g(\omega_p^p))\}​​

之后令 H(x)H(x)​ 为它们的卷积。那么思路就是利用一遍 IFFT:

H(x)={(ωp0,f(ωp0)g(ωp0)),(ωp1,f(ωp1)g(ωp1)),,(ωpp,f(ωpp)g(ωpp))}H(x)=c0+c1x+c2x2++cpxpH(x) = \{(\omega_p^0,f(\omega_p^0) \cdot g(\omega_p^0)),(\omega_p^1,f(\omega_p^1) \cdot g(\omega_p^1)), \cdots,(\omega_p^p,f(\omega_p^p) \cdot g(\omega_p^p))\} \Rightarrow H(x) = c_0+c_1x+c_2x^2+\cdots +c_{p}x^{p}​​​ 。

于是 c0,c1,cn+mc_0, c_1, \cdots c_{n + m}​ 就是答案。代码(用时 2.09s,空间 282.14MB):

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1 << 22 | 1;
struct Complex {
	double x, y; // x + yi
	Complex() {x = y = 0.0;}
	Complex(double xx, double yy) {x = xx; y = yy;}
	Complex operator + (Complex A) {
		return Complex(x + A.x, y + A.y);
	}
	Complex operator - (Complex A) {
		return Complex(x - A.x, y - A.y);
	}
	Complex operator * (Complex A) {
		return Complex(x * A.x - y * A.y, x * A.y + y * A.x);
	}
};
LL A[N], B[N];
Complex fA[N], fB[N], fAns[N], Ans[N];
int Rev[N];
void FFT(Complex *F, int n, int Sign) { // Sign = 1: FFT , Sign = -1: IFFT
	for (int i = 0; i < n; i++) {
		if (i < Rev[i]) swap(F[i], F[Rev[i]]);
	}
	for (int len = 2; len <= n; len <<= 1) {
		int mid = len >> 1;
		Complex wl = Complex(cos(M_PI / mid), Sign * sin(M_PI / mid));
		for (int i = 0; i < n; i += len) {
			Complex Ang = Complex(1, 0);
			for (int j = 0; j < mid; j++) {
				Complex tx = F[i + j], ty = F[i + j + mid];
				F[i + j] = tx + Ang * ty; F[i + j + mid] = tx - Ang * ty;
				Ang = Ang * wl;
			}
		}
	}
    if (Sign == -1) {
        for (int i = 0; i < n; i++) F[i].x /= n;
    }
}
int main() {
	int n, m, p = 0;
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n + 1; i++) scanf("%lld", A + i), fA[i - 1].x = A[i];
	for (int i = 1; i <= m + 1; i++) scanf("%lld", B + i), fB[i - 1].x = B[i];
	while ((1 << p) <= n + m) p++;
	for (int i = 1, End = (1 << p); i < End; i++) {
		Rev[i] = (Rev[i >> 1] >> 1) | ((i & 1) << (p - 1));
	}
	FFT(fA, 1 << p, 1); FFT(fB, 1 << p, 1); // FFT
	for (int i = 0, End = (1 << p); i < End; i++) {
		fAns[i] = fA[i] * fB[i];
	}
	FFT(fAns, 1 << p, -1); // IFFT
	for (int i = 0; i <= n + m; i++) {
		printf("%d%c", (int)(fAns[i].x + 0.5), " \n"[i == n + m]); // 四舍五入, 防卡精度
	}
	return 0;
}

NTT 版本的代码(用时 1.71s,空间 82.81MB):

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1 << 22 | 1;
const LL mod = 998244353;
LL A[N], B[N];
LL fA[N], fB[N], fAns[N], Ans[N];
int Rev[N];
LL qpow(LL x, LL p, LL mod = 998244353ll) {
    LL ret = 1ll;
    while (p) {
        if (p & 1) (ret *= x) %= mod;
        (x *= x) %= mod; p >>= 1;
    }
    return ret;
}
void NTT(LL *F, int n, int Sign) { // Sign = 1: NTT , Sign = -1: INTT
    for (int i = 0; i < n; i++) {
        if (i < Rev[i]) swap(F[i], F[Rev[i]]);
    }
    for (int len = 2; len <= n; len <<= 1) {
        int mid = (len >> 1);
        LL wl = qpow(3, (mod - 1) / len);
        for (int i = 0; i < n; i += len) {
            LL Ang = 1;
            for (int j = 0; j < mid; j++) {
                LL tx = F[i + j], ty = F[i + j + mid];
                F[i + j] = (tx + Ang * ty) % mod; F[i + j + mid] = ((tx - Ang * ty) % mod + mod) % mod;
                (Ang *= wl) %= mod;
            }
        }
    }
    if (Sign == -1) {
        LL Inv_n = qpow(n, mod - 2);
        for (int i = 1; i < (n >> 1); i++) swap(F[i], F[n - i]);
        for (int i = 0; i < n; i++) (F[i] *= Inv_n) %= mod;
    }
}
int main() {
    int n, m, p = 0;
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n + 1; i++) scanf("%lld", A + i), fA[i - 1] = A[i];
	for (int i = 1; i <= m + 1; i++) scanf("%lld", B + i), fB[i - 1] = B[i];
	while ((1 << p) <= n + m) p++;
	for (int i = 1, End = (1 << p); i < End; i++) {
		Rev[i] = (Rev[i >> 1] >> 1) | ((i & 1) << (p - 1));
	}
	NTT(fA, 1 << p, 1); NTT(fB, 1 << p, 1); // NTT
	for (int i = 0, End = (1 << p); i < End; i++) {
		fAns[i] = (fA[i] * fB[i]) % mod;
	}
	NTT(fAns, 1 << p, -1); // INTT
    for (int i = 0; i <= n + m; i++) {
        printf("%lld%c", fAns[i], " \n"[i == n + m]);
    }
	return 0;
}

可以发现,用 NTT 交模板题比用 FFT 交模板题的评测用时短,内存用的少,也少去了手写复数的麻烦,会用得比较多。

“两次变一次”优化

我们发现对于多项式 PPQQ,分别将其放在复数的实部和虚部,有 (P+Qi)2=P2Q2+2(PQ)i(P + Q \text{i})^2 = P^2 - Q^2 + 2(P \cdot Q) \text{i}

也就是说,我们可以先将一个多项式的实部和虚部分别放上要卷积的两个多项式 PPQQ ,再将这个多项式用 FFT,将该多项式平方后利用 IFFT 求出 (P+Qi)2(P + Q \text{i})^2 ,结果就是虚部再除以 22

可以发现,整个过程只用了一次 FFT(IFFT 不算),而前面的就要用两次。这就是“两次变一次”优化。代码(用时 1.79s,空间 285.25MB):

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1 << 22 | 1;
struct Complex {
	double x, y; // x + yi
	Complex() {x = y = 0.0;}
	Complex(double xx, double yy) {x = xx; y = yy;}
};
Complex operator + (Complex A, Complex B) {
	return Complex(A.x + B.x, A.y + B.y);
}
Complex operator - (Complex A, Complex B) {
	return Complex(A.x - B.x, A.y - B.y);
}
Complex operator * (Complex A, Complex B) {
	return Complex(A.x * B.x - A.y * B.y, A.x * B.y + A.y * B.x);
}
LL A[N], B[N];
Complex fA[N], fB[N], fAns[N], Ans[N];
int Rev[N];
void FFT(Complex *F, int n, int Sign) { // Sign = 1: FFT , Sign = -1: IFFT
	for (int i = 0; i < n; i++) {
		if (i < Rev[i]) swap(F[i], F[Rev[i]]);
	}
	for (int len = 2; len <= n; len <<= 1) {
		int mid = len >> 1;
		Complex wl = Complex(cos(M_PI / mid), Sign * sin(M_PI / mid));
		for (int i = 0; i < n; i += len) {
			Complex Ang = Complex(1, 0);
			for (int j = 0; j < mid; j++) {
				Complex tx = F[i + j], ty = F[i + j + mid];
				F[i + j] = tx + Ang * ty; F[i + j + mid] = tx - Ang * ty;
				Ang = Ang * wl;
			}
		}
	}
    if (Sign == -1) {
        for (int i = 0; i < n; i++) F[i] = Complex(F[i].x / n, F[i].y / n);
    }
}
int main() {
	int n, m, p = 0;
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n + 1; i++) scanf("%lld", A + i), fA[i - 1].x = A[i];
	for (int i = 1; i <= m + 1; i++) scanf("%lld", B + i), fA[i - 1].y = B[i];
	while ((1 << p) <= n + m) p++;
	for (int i = 1, End = (1 << p); i < End; i++) {
		Rev[i] = (Rev[i >> 1] >> 1) | ((i & 1) << (p - 1));
	}
	FFT(fA, 1 << p, 1);
	for (int i = 0, End = (1 << p); i < End; i++) {
		fAns[i] = fA[i] * fA[i];
	}
	FFT(fAns, 1 << p, -1);
	for (int i = 0; i <= n + m; i++) {
		printf("%d%c", (int)(fAns[i].y / 2.0 + 0.5), " \n"[i == n + m]); // 四舍五入防止卡精度
	}
	return 0;
}

P1919【模板】A*B Problem升级版(FFT快速傅里叶)

题目链接

给定两个正整数 a,ba, b ,求 a×ba \times b​ 。其中 1a,b1010000001 \le a, b \le 10^{1000000}

样例输入:114514 1919810 ,输出:219845122340

题目解答

对于一个 nn 位的十进制数,将其看作是一个 n1n - 1 次的多项式 AA ,满足:

A(x)=a0+a1×10+a2×102++an1×10n1A(x) = a_0 + a_1 \times 10 + a_2 \times 10^2 + \cdots + a_{n - 1} \times 10^{n - 1}

那么对于两个大整数相乘,我们就可以卷积了。于是就和模板题一样。

当然你最后还是要进位的。

不过,这道题不是很重要。如果你不想写这道题,并且你过了上一题,又想得到这道题的分数,那就直接用 Ruby 吧。

a = Integer(gets)
b = Integer(gets)
puts a * b

如果说 Ruby 过不去,你可以试试 Haskell:

main = do
a <- getLine
b <- getLine
print(read a * read b)

为了更快更好,你也可以使用 Python 3:

import decimal
decimal.getcontext().prec = 10000000
a = input(); b = input()
e = 1
if a[0] == '-': e = -e; a = a[1:]
a = '0.' + a
if b[0] == '-': e = -e; b = b[1:]
b = '0.' + b
a = decimal.Decimal(a); b = decimal.Decimal(b)
c = str(a * b); c = c[2:]
if c[0] == '0': c = c[1:]
if e == -1: c = '-' + c
print(c)

当然,作者也写了,但是用的是 C,就不放在这里了。

[ZJOI2014]力

题目链接

给定 nn 个数 q1,q2,qnq_1, q_2, \cdots q_n ,定义 Fj=i=1j1qiqj(ji)2i=j+1nqiqj(ji)2\begin{aligned} F_j = \sum_{i = 1}^{j - 1} \frac{q_i \cdot q_j}{(j - i) ^ 2} - \sum_{i = j + 1}^n \frac{q_i \cdot q_j}{(j - i)^2} \end{aligned} ,以及 Ei=Fiqi\begin{aligned} E_i = \frac{F_i}{q_i} \end{aligned} 。现在请你对 1in1 \le i \le n ,求 EiE_i 的值。

1n1051 \le n \le 10^50<qi<1090 < q_i < 10^9

样例输入:3 3.14 19.26 8.17 ,输出:-21.302 -5.030 20.045

题目解答

这题是一个凑卷积的题。

可以发现,这题中 Ei=j=1i1qj(ji)2j=i+1nqj(ji)2\begin{aligned} E_i = \sum_{j = 1}^{i - 1} \frac{q_j}{(j - i) ^ 2} - \sum_{j = i + 1}^n \frac{q_j}{(j - i)^2} \end{aligned} 。我们将 EiE_i 拆成两部分:Ai=j=1i1qj(ji)2\begin{aligned} A_i = \sum_{j = 1}^{i - 1} \frac{q_j}{(j - i) ^ 2} \end{aligned}Bi=j=i+1nqj(ji)2\begin{aligned} B_i = \sum_{j = i + 1}^n \frac{q_j}{(j - i)^2} \end{aligned}

对于 AiA_i ,有 Ai=j=1i1qj(ji)2=j=1i1qj1(ij)2\begin{aligned} A_i = \sum_{j = 1}^{i - 1} \frac{q_j}{(j - i) ^ 2} = \sum_{j = 1}^{i - 1} q_j \cdot \frac{1}{(i - j)^2} \end{aligned}

定义 gi=1i2\begin{aligned} g_i = \frac{1}{i^2} \end{aligned} ,那么有 Ai=j=1i1qj1(ij)2=j=1i1qjgij\begin{aligned} A_i = \sum_{j = 1}^{i - 1} q_j \cdot \frac{1}{(i - j)^2} = \sum_{j = 1}^{i - 1} q_j \cdot g_{i - j} \end{aligned} ,是个卷积形式,可以用 FFT 进行加速。

对于 BiB_i ,有 Bi=j=i+1nqj(ji)2=j=i+1nqj1(ji)2\begin{aligned} B_i = \sum_{j = i + 1}^n \frac{q_j}{(j - i)^2} = \sum_{j = i + 1}^n q_j \cdot \frac{1}{(j - i)^2} \end{aligned}

还是定义 gi=1i2\begin{aligned} g_i = \frac{1}{i^2} \end{aligned} ,那么有 Bi=j=i+1nqj1(ji)2=j=i+1nqjgji\begin{aligned} B_i = \sum_{j = i + 1}^{n} q_j \cdot \frac{1}{(j - i)^2} = \sum_{j = i + 1}^{n} q_j \cdot g_{j - i} \end{aligned}

可以发现,这并不是一个标准的卷积形式。我们可以将序列 qq 进行反转。令 pi=qn+1ip_i = q_{n + 1 - i} 。那么有 gi=1i2\begin{aligned} g_i = \frac{1}{i^2} \end{aligned} ,那么有 Bi=j=i+1nqjgji=j=i+1npn+1jgji\begin{aligned} B_i = \sum_{j = i + 1}^{n} q_j \cdot g_{j - i} = \sum_{j = i + 1}^{n} p_{n + 1 - j} \cdot g_{j - i} \end{aligned}

乍一看这不是卷积的形式,不过如果令 a=n+1ja = n + 1 - j ,那么 Bi=j=i+1npn+1jgji=a=1nipag(n+1i)a\begin{aligned} B_i = \sum_{j = i + 1}^{n} p_{n + 1 - j} \cdot g_{j - i} = \sum_{a = 1}^{n - i} p_a \cdot g_{(n + 1 - i) - a} \end{aligned}

当然也可以如下理解:Bi=j=i+1npn+1jgji=x+y=n+1ipxgy\begin{aligned} B_i = \sum_{j = i + 1}^{n} p_{n + 1 - j} \cdot g_{j - i} = \sum_{x + y = n + 1 - i} p_x \cdot g_{y} \end{aligned}

于是这又是一个卷积形式,用 FFT 就做完了。

最后将 AiA_i​ 和 BiB_i​ 分别计算后,Ei=AiBn+1iE_i = A_i - B_{n + 1 - i}​ 。

总结小 trick:

若要求类似于 Pk=j=i+kAiBjP_k = \sum_{j = i + k}^{} A_iB_j​​​​​​​ 的式子,可以将序列 A0,A1,,AnA_0, A_1, \cdots, A_n​​​​​​​ 翻转,那么 Pk=i+j=n+kAiBjP_k = \sum_{i + j = n + k} A_iB_j 。​​​​​

[AH2017/HNOI2017]礼物

题目链接

给定长度为 nn​ 序列 x1,2,,nx_{1, 2, \cdots , n}​ 和 y1,2,,ny_{1, 2, \cdots , n}​ ,满足 i[1,n]\forall i \in [1, n]​ ,都有 xi,yimx_i, y_i \le m​ 。需要找到两个整数 k0k \ge 0​ 和 cc​ ,使得 i=1n(xiyi+k+c)2\sum_{i = 1}^n (x_i - y_{i + k} + c)^2​ 最小。要求这个最小值。
其中对于 iN\forall i \in \mathbb{N^*} ,有  yn+i=yi\ y_{n + i} = y_i​​​​ 。

1n500001 \le n \le 500001xi,yim1001 \le x_i, y_i \le m \le 100​ 。

样例输入:5 6\n1 2 3 4 5\n6 3 3 4 5 ,输出:1

题目解答

这题和上题一样,也是 FFT 经典题(没准比上题还要简单)。

可以发现,S=i=1n(xiyi+k+c)2=i=1n(xi2+yi2)+2i=1n(xiyi)c+nc22i=1nxiyi+kS = \sum_{i = 1}^n (x_i - y_{i + k} + c)^2 = \sum_{i = 1}^n (x_i^2 + y_i^2) + 2\cdot \sum_{i = 1}^n(x_i - y_i) \cdot c + n \cdot c^2 - 2\cdot \sum_{i = 1}^n x_i \cdot y_{i + k}

可以将 SS 拆成三部分:i=1n(xi2+yi2)\sum_{i = 1}^n (x_i^2 + y_i^2)2i=1n(xiyi)c+nc22\cdot \sum_{i = 1}^n(x_i - y_i) \cdot c + n \cdot c^22i=1nxiyi+k- 2\cdot \sum_{i = 1}^n x_i \cdot y_{i + k}

对于第一部分:由于该值不受 k,ck, c 的影响,为恒定值,可以不用考虑。

对于第二部分:可将其看作一个关于 cc​​​​ 的二次函数。而我们要求的是该二次函数的最值。显然在 c=i=1n(xiyi)nc = -\frac{\sum_{i = 1}^n(x_i - y_i)}{n}​​​​ 时成立,而因为 cc​​​​ 必须为整数,因此需要看 i=1n(xiyi)n\lfloor -\frac{\sum_{i = 1}^n(x_i - y_i)}{n} \rfloor​​​​ 和 i=1n(xiyi)n\lceil -\frac{\sum_{i = 1}^n(x_i - y_i)}{n} \rceil​​​​ 与 i=1n(xiyi)n-\frac{\sum_{i = 1}^n(x_i - y_i)}{n}​​​​ 的距离。距离越小,这个二次函数的值越小。

对于第三部分:我们可以令多项式 F(x)=i=1nxn+1ixiF(x) = \sum_{i = 1}^{n} x_{n + 1 - i} x^iG(x)=i=1nyixi+i=n+12nyinxiG(x) = \sum_{i = 1}^n y_ix^i + \sum_{i = n + 1}^{2n} y_{i - n}x^i 。以及函数 h(k)=i=1nxiyi+kh(k) = \sum_{i = 1}^n x_i \cdot y_{i + k}

H(x)=F(x)×G(x)H(x) = F(x) \times G(x) ,那么有 h(k)=[xn+1+k]H(x)h(k) = [x^{n + 1 + k}] H(x)

而我们要求的是 maxi=0n1h(i)=maxi=0n1[xn+1+i]H(x)\begin{aligned} \max_{i = 0}^{n - 1} h(i) = \max_{i = 0}^{n - 1} [x^{n + 1 + i}] H(x) \end{aligned}​ ,正好需要把 H(x)H(x)​ 给求出来,于是就可以用 FFT 了。最后需要注意 FFT 的精度问题。

参考代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1 << 20 | 1;
struct Complex {
	double x, y; // x + yi
	Complex() {x = y = 0.0;}
	Complex(double xx, double yy) {x = xx; y = yy;}
};
Complex operator + (Complex A, Complex B) {
	return Complex(A.x + B.x, A.y + B.y);
}
Complex operator - (Complex A, Complex B) {
	return Complex(A.x - B.x, A.y - B.y);
}
Complex operator * (Complex A, Complex B) {
	return Complex(A.x * B.x - A.y * B.y, A.x * B.y + A.y * B.x);
}
LL A[N], B[N];
Complex fA[N], fB[N], fAns[N], Ans[N];
int Rev[N];
void FFT(Complex *F, int n, int Sign) { // Sign = 1: FFT , Sign = -1: IFFT
	for (int i = 0; i < n; i++) {
		if (i < Rev[i]) swap(F[i], F[Rev[i]]);
	}
	for (int len = 2; len <= n; len <<= 1) {
		int mid = len >> 1;
		Complex wl = Complex(cos(M_PI / mid), Sign * sin(M_PI / mid));
		for (int i = 0; i < n; i += len) {
			Complex Ang = Complex(1, 0);
			for (int j = 0; j < mid; j++) {
				Complex tx = F[i + j], ty = F[i + j + mid];
				F[i + j] = tx + Ang * ty; F[i + j + mid] = tx - Ang * ty;
				Ang = Ang * wl;
			}
		}
	}
    if (Sign == -1) {
        for (int i = 0; i < n; i++) F[i] = Complex(F[i].x / n, F[i].y / n);
    }
}
LL Floor(LL Up, LL Down) { // 分子 / 分母 向下取整
    if (Down < 0) Down = -Down, Up = -Up;
    if (Up >= 0) return Up / Down;
    return (Up - Down + 1) / Down;
}
LL Ceil(LL Up, LL Down) { // 分子 / 分母 向上取整
    if (Down < 0) Down = -Down, Up = -Up;
    if (Up <= 0) return Up / Down;
    return (Up + Down - 1) / Down;
}
int main() {
	int n, m, p = 0;
    LL fir = 0ll, sec = 0ll, thr = 0ll;
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; i++) scanf("%lld", A + i), fA[n + 1 - i].x = A[i];
	for (int i = 1; i <= n; i++) scanf("%lld", B + i), fB[i + n].x = fB[i].x = B[i];
    // 第一部分
    for (int i = 1; i <= n; i++) fir += A[i] * A[i] + B[i] * B[i];
    // 第二部分
    LL tsum = 0ll, _Ceil, _Floor;
    for (int i = 1; i <= n; i++) tsum += A[i] - B[i];
    _Ceil = Ceil(-tsum, n) * n; _Floor = Floor(-tsum, n) * n;
    if (-tsum <= _Floor + ((_Ceil - _Floor) >> 1)) sec = 2 * tsum * (_Floor / n) + _Floor / n * _Floor;
    else sec = 2 * tsum * (_Ceil / n) + _Ceil / n * _Ceil;
    // 第三部分
	while ((1 << p) < 3 * n) p++;
	for (int i = 1, End = (1 << p); i < End; i++) {
		Rev[i] = (Rev[i >> 1] >> 1) | ((i & 1) << (p - 1));
	}
	FFT(fA, 1 << p, 1); FFT(fB, 1 << p, 1);
	for (int i = 0, End = (1 << p); i < End; i++) {
		fAns[i] = fA[i] * fB[i];
	}
	FFT(fAns, 1 << p, -1);
	for (int i = n + 1; i <= (n << 1); i++) {
        if (thr < (LL)(fAns[i].x + 0.5)) thr = (LL)(fAns[i].x + 0.5); // 减少精度误差
	}
    thr = (-2) * thr;
    // 结果输出
    printf("%lld\n", fir + sec + thr);
	return 0;
}

[SDOI2015]序列统计

题目链接

给定一个集合 SS ,满足 S{0,1,,m1}S \subseteq \{0, 1, \cdots, m - 1\} ,求从这个集合 SS 中有顺序地选取 nn 个数(可以重复),使得这 nn 个数的乘积 modm\operatorname{mod} mxx 的方案数 模上 10045358091004535809 的值。其中 mm 为质数。

对于 100%100\% 的数据,1n1091 \le n \le 10^93m80003 \le m \le 80001x<m1 \le x < m。保证 mm 为质数,且 SS 中的元素两两不同。

样例输入:4 3 1 2\n1 2 ,输出:8

样例解释:可以生成的满足要求的不同的数列有(1,1,1,1)(1,1,2,2)(1,2,1,2)(1,2,2,1)(2,1,1,2)(2,1,2,1)(2,2,1,1)(2,2,2,2)

题目解答

这是个较为经典的问题。下面引入一些简单的题目来进行解答。

例题 1

给定一个集合 SS,满足 S{0,1,,n}S \subseteq \{ 0, 1, \cdots , n \} ,在这个集合里选取 33 个数(可以重复),使得加和为 AA 的方案数。

例题 1 解答

可以运用桶的思想。设非负整数 xx 在集合 SS 中出现了 h(x)h(x) 次,那么可以得到答案为 i=0Aj=0Aih(i)×h(j)×h(Aij)\begin{aligned} \sum_{i = 0}^A \sum_{j = 0}^{A - i} h(i) \times h(j) \times h(A - i - j) \end{aligned}

时间复杂度是 O(A2)O(A^2) ,有点慢。不过我们可以先求出从 SS 中选取 22 个数的方案。设从 SS 中选取 22 个数的加和为 xx 的方案数为 h2(x)h_2(x) ,则可以得到答案为 i=0Ah(i)×h2(Ai)\begin{aligned} \sum_{i = 0}^A h(i) \times h_2(A - i) \end{aligned} 。可以 O(A)O(A) 求出。

那么如何快速求出 h2h_2 呢?可以发现,h2(x)=i=0xh(i)×h(xi)\begin{aligned} h_2(x) = \sum_{i = 0}^x h(i) \times h(x - i) \end{aligned} 。这是个标准的卷积形式,可以使用 FFT/NTT 快速求出 h2h_2

这样,总的复杂度就是 O(nlogn)O (n \log n) 了。

例题 2

给定一个集合 SS,满足 S{0,1,,n}S \subseteq \{ 0, 1, \cdots , n \} ,在这个集合里选取 kk 个数(可以重复),使得加和为 AA 的方案数。

例题 2 解答

类比于例题 1,不难得到一种做法:设从 SS 中选取 tt 个数的加和为 xx 的方案数为 ht(x)h_t(x) ,则有这样的递推式:ht(x)=i=0Ah(i)×ht1(xi)\begin{aligned} h_t(x) = \sum_{i = 0}^A h(i) \times h_{t - 1}(x - i) \end{aligned}

可以发现,这样做的话复杂度为 O(kAlogA)O(k A \log A) 的,还可以在进行优化。

可以考虑倍增。可以发现,选取 2r2^r 个数相当于在选取 2r12^{r - 1} 个数的 可能的和 的集合中选取 22 个数。那么就有递推式:h2r(x)=i=0Ah2r1(i)×h2r1(xi)\begin{aligned} h_{2^r}(x) = \sum_{i = 0}^A h_{2^{r - 1}}(i) \times h_{2^{r - 1}}(x - i) \end{aligned}

如此这般,就可以预处理处所有的 h2rh_{2^r} ,之后进行与快速幂类似的操作,将 kk 二进制拆解,可以在 O(AlogAlogk)O(A \log A \log k) 的复杂度下实现。

例题 3

给定一个集合 SS,满足 S{0,1,,n}S \subseteq \{ 0, 1, \cdots , n \} ,在这个集合里选取 kk 个数(可以重复),使得加和在 modm\operatorname{mod} m 的意义下为 AA 的方案数。

例题 3 解答

和例题 2 相比,这一题需要的是模意义下的加法,因此核心算法和上一题一样,只是有一个细节要注意:多项式乘法后,需要把和 m\ge m 的项累加到对应模运算以后的位置,因为这些也是可能的方案,对答案有贡献。

假设 ansians_i 表示在集合 SS 里选取 kk 个数(可以重复),使得加和在 modm\operatorname{mod} m 的意义下为 ii 的方案数。那么 ansi=j=0+hk(mj+i)ans_i = \sum_{j = 0}^{+ \infty} h_k(m \cdot j + i) 。(这里 ht(x)h_t(x) 的定义和上面同样)

回到原题

可以发现,原题和例题 3 之间,唯一的变化是加法变成了乘法。

因此,我们需要把乘法变成加法

于是我们可以这么干:

由于 mm 是质数,所以可以找到一个 mm 的原根 gg

可以发现,对于一个正整数 i[1,m1]i \in [1, m - 1] ,都有一个 qi[1,m1]q_i \in [1, m - 1] 满足 igqi(modm)i \equiv g^{q_i} (\operatorname{mod} m) 。且 qiq_i 两两不同。

于是乎原题

“对于所有满足 i\forall i 都有 aiSa_i \in S 的序列 aa 中,有多少个序列 aa 满足 i=1naix (modm)\prod_{i = 1}^n a_i \equiv x\ (\operatorname{mod} m)”)

等价于

“令集合 T={xxgy(modm),yS,x[1,m1]}T = \{ x \vert x \equiv g^y (\operatorname{mod} m), y \in S, x \in [1, m - 1]\}​ 。对于 i\forall i​ 都有 biTb_i \in T​ 的序列 bb​ 中,有多少个序列 bb​ 满足 i=1nbiqx(mod m1)\sum_{i = 1}^n b_i \equiv q_x (\operatorname{mod}\ m - 1)​”

于是乎,原题就转化成了例题 3。可喜可贺!

不过,这里貌似有个问题:00​ 也有可能属于集合 SS​ ,为什么不考虑呢?
因为 1x<m1 \le x < m​ ,因此 00​ 对答案无贡献,所以可以忽略 00​ 对答案的影响。

参考代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1 << 20 | 5;
const LL mod = 1004535809;
LL Ans[N], nowH[N], G[N];
int Rev[N], n, m, numTo, q[N];
inline LL qpow(LL n, LL p, LL mod = 1004535809ll) {
    LL ret = 1ll;
    while (p) {
        if (p & 1) (ret *= n) %= mod;
        (n *= n) %= mod; p >>= 1;
    }
    return ret;
}
void NTT(LL *F, int n, int Sign) { // Sign = 1: NTT , Sign = -1: INTT
	for (int i = 0; i < n; i++) {
		if (i < Rev[i]) swap(F[i], F[Rev[i]]);
	}
	for (int len = 2; len <= n; len <<= 1) {
		int mid = len >> 1;
        LL wl = qpow(3, (mod - 1) / len);
		for (int i = 0; i < n; i += len) {
			LL Ang = 1;
			for (int j = 0; j < mid; j++) {
				LL tx = F[i + j], ty = (F[i + j + mid] * Ang) % mod;
				F[i + j] = (tx + ty) % mod; F[i + j + mid] = (tx - ty + mod) % mod;
				Ang = (Ang * wl) % mod;
			}
		}
	}
    if (Sign == -1) {
        for (int i = 1; i < (n >> 1); i++) swap(F[i], F[n - i]);
        LL Invn = qpow(n, mod - 2);
        for (int i = 0; i < n; i++) {
            (F[i] *= Invn) %= mod;
        }
    }
}
void qpow(int n, int p) {
    // 倍增, 类比快速幂
    memset(Ans, 0, sizeof Ans); Ans[0] = 1;
    while (n) {
        if (n & 1) {
            for (int i = 0; i < m - 1; i++) G[i] = nowH[i];
            for (int i = m - 1; i < (1 << p); i++) G[i] = 0;
            NTT(Ans, 1 << p, 1); NTT(G, 1 << p, 1);
            for (int i = 0; i < (1 << p); i++) Ans[i] = (Ans[i] * G[i]) % mod;
            NTT(Ans, 1 << p, -1);
            for (int i = m - 1; i < (1 << p); i++) (Ans[i % (m - 1)] += Ans[i]) %= mod, Ans[i] = 0;
        }
        NTT(nowH, 1 << p, 1);
        for (int i = 0; i < (1 << p); i++) nowH[i] = (nowH[i] * nowH[i]) % mod;
        NTT(nowH, 1 << p, -1);
        for (int i = m - 1; i < (1 << p); i++) (nowH[i % (m - 1)] += nowH[i]) %= mod, nowH[i] = 0;
        n >>= 1;
    }
}
bool vis_for_g[N];
int main() {
    int tSlen, p = 0, g;
	scanf("%d%d%d%d", &n, &m, &numTo, &tSlen);
    // 1. 得到 m 的原根 g
    for (g = 1; g < m; g++) {
        memset(vis_for_g, 0, sizeof vis_for_g);
        int now = 1; bool flag = 1; // flag 判断是否为原根
        for (int i = 1; i < m; i++) {
            (now *= g) %= m; if (vis_for_g[now]) {flag = 0; break;}
            vis_for_g[now] = 1;
            // 如果存在 (i, j) 使得 g^i = g^j (mod m), 则 g 一定不是原根
        }
        if (flag) break; // 如果是原根, 就退出循环
    }
    // 1.1 求出 q_i
    for (int now = 1, i = 1; i < m; i++) {
        (now *= g) %= m; q[now] = i;
    }
    for (int num; tSlen--; ) { // 输入集合 S
        scanf("%d", &num);
        if (num) nowH[q[num] % (m - 1)] = 1; // 忽略 S 中的元素 0
    }
    // NTT 的准备
	while ((1 << p) <= (m << 1) - 2) p++;
	for (int i = 1, End = (1 << p); i < End; i++) {
		Rev[i] = (Rev[i >> 1] >> 1) | ((i & 1) << (p - 1));
	}
    // 2. 倍增
    qpow(n, p);
    // 在集合 S 里选取 n 个数(可以重复), 使得加和在 mod m 的意义下为 q_x 的方案数
    printf("%lld\n", Ans[q[numTo] % (m - 1)]);
	return 0;
}

CF755G PolandBall and Many Other Balls

题目链接

nn 个球在一排上,定义一个可以只包含 11 个球或者包含 22 个相邻的球,而一个球最多可以分到一个中,对于 m[1,k]\forall m \in [1, k] ,求从这些球中取出 mm 组的方案数 模上 998244353998244353 的值。

1n1051 \le n \le 10^51k<2151 \le k < 2^{15}

样例输入:3 3 ,输出:5 5 1

样例解释:
k=1k = 1 时,有 (1),(2),(3),(12),(23)(1), (2), (3), (12), (23) ,共 55 种方案。
k=2k = 2 时,有 (12)(3),(1)(23),(1)(2),(1)(3),(2)(3)(12)(3),(1)(23),(1)(2),(1)(3),(2)(3) ,共 55 种方案。
k=3k = 3 时,有 (1)(2)(3)(1)(2)(3) ,共 11 种方案。

题目解答

这是个利用 FFT/NTT 来优化 dp 的题目。

先来看一个显然的 dp:假设 fn,kf_{n, k}​ 为 nn​ 个球,取出 kk​ 组的方法数。那么 fn,k=fn1,k+fn1,k1+fn2,k1f_{n, k} = f_{n - 1, k} + f_{n - 1, k - 1} + f_{n - 2, k - 1}​​,也就是考虑最后一个球的分组情况(不在任何一组、在新的一组中、和前面的求在一组)。

然而只有这个 dp 还不够,我们再来看一个比较神奇的转移方式:

将一段长度为 aa 的球和一段长度为 bb 的球合并成一段长度为 a+ba + b 的球(长度为 aa 的在左,长度为 bb 的在右),那么要分两种情况讨论:

第一种:长度为 aa 的一段球的最后一个球和长度为 bb 的一段球的第一个球不放在同一组里,那么 fa+b,k=i=0kfa,ifb,kif_{a + b, k} = \sum_{i = 0}^k f_{a, i} \cdot f_{b, k - i}​ 。

第二种:长度为 aa​ 的一段球的最后一个球和长度为 bb​ 的一段球的第一个球放在同一组里,那么 fa+b,k=i=0k1fa1,i1fb1,ki1f_{a + b, k} = \sum_{i = 0}^{k - 1} f_{a - 1, i - 1} \cdot f_{b - 1, k - i - 1}​ 。

汇总一下,也就是 fa+b,k=i=0kfa,ifb,ki+i=0k1fa1,ifb1,ki1\begin{aligned} f_{a + b, k} = \sum_{i = 0}^k f_{a, i} \cdot f_{b, k - i} + \sum_{i = 0}^{k - 1} f_{a - 1, i} \cdot f_{b - 1, k - i - 1} \end{aligned}​​​ 。

Fn(x)=k=0+fn,kxk\begin{aligned} F_n(x) = \sum_{k = 0}^{+\infty} f_{n, k} x^k \end{aligned}​​​ 。那么上面的式子可以变为:Fa+b(x)=Fa(x)Fb(x)+xFa1(x)Fb1(x)\begin{aligned} F_{a + b}(x) = F_a(x)F_b(x) + x F_{a - 1}(x)F_{b - 1}(x)\end{aligned}​​ 。(最后一个多项式要乘 xx​​ 是为了让次数总合为 k1k - 1 的项转移到次数为 kk

而第一个式子可以变为 Fn(x)=Fn1(x)+xFn1(x)+xFn2(x)F_n(x) = F_{n - 1}(x) + xF_{n - 1}(x) + xF_{n - 2}(x)

于是就可以如下倍增:

{F2n(x)=Fn2(x)+xFn12(x)F2n1(x)=Fn(x)Fn1(x)+xFn1(x)Fn2(x)F2n2(x)=Fn12(x)+xFn22(x)\begin{cases} F_{2n}(x) = F_n^2(x) + xF_{n - 1}^2(x) \\ F_{2n - 1}(x) = F_n(x)F_{n - 1}(x) + xF_{n - 1}(x)F_{n - 2}(x) \\ F_{2n - 2}(x) = F_{n - 1}^2(x) + xF_{n - 2}^2(x) \end{cases}

也就是说,只要知道 Fn2(x),Fn1(x),Fn(x)F_{n - 2}(x), F_{n - 1}(x), F_n(x) ,就可以知道 F2n2(x),F2n1(x),F2n(x)F_{2n - 2}(x), F_{2n - 1}(x), F_{2n}(x) 了。

而我们通过一开始的式子就可以有 Fn(x)Fn+1(x)F_n(x) \rightarrow F_{n + 1}(x) 。于是乎就可以倍增 FFT 了。

时间复杂度 O(nlog2n)O(n \log^2 n)​​ 。初始化 F0(x)=1F_0(x) = 1​​​ 。

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1 << 16 | 1;
const LL mod = 998244353ll;
LL F1[N], F2[N], F3[N];    // 倍增前 (n - 2, n - 1, n)
LL F21[N], F22[N], F23[N]; // 倍增后 (2n - 2, 2n - 1, 2n)
int Rev[N], n, k;
inline LL qpow(LL n, LL p, LL mod = 998244353ll) {
    LL ret = 1ll;
    while (p) {
        if (p & 1) (ret *= n) %= mod;
        (n *= n) %= mod; p >>= 1;
    }
    return ret;
}
void NTT(LL *F, int n, int Sign) { // Sign = 1: NTT , Sign = -1: INTT
	for (int i = 0; i < n; i++) {
		if (i < Rev[i]) swap(F[i], F[Rev[i]]);
	}
	for (int len = 2; len <= n; len <<= 1) {
		int mid = len >> 1;
        LL wl = qpow(3, (mod - 1) / len);
		for (int i = 0; i < n; i += len) {
			LL Ang = 1;
			for (int j = 0; j < mid; j++) {
				LL tx = F[i + j], ty = (F[i + j + mid] * Ang) % mod;
				F[i + j] = (tx + ty) % mod; F[i + j + mid] = (tx - ty + mod) % mod;
				Ang = (Ang * wl) % mod;
			}
		}
	}
    if (Sign == -1) {
        for (int i = 1; i < (n >> 1); i++) swap(F[i], F[n - i]);
        LL Invn = qpow(n, mod - 2);
        for (int i = 0; i < n; i++) {
            (F[i] *= Invn) %= mod;
        }
    }
}
void Pow2() {  // n -> 2n
    NTT(F1, 65536, 1); NTT(F2, 65536, 1); NTT(F3, 65536, 1);
    LL gn = qpow(3, (mod - 1) >> 16), now = 1ll;
    // F2n(x) = Fn(x)^2 + xFn-1(x)^2
    now = 1ll;
    for (int i = 0; i < 65536; i++, (now *= gn) %= mod) F23[i] = (F3[i] * F3[i] + now * ((F2[i] * F2[i]) % mod)) % mod;
    NTT(F23, 65536, -1);
    // F2n-1(x) = Fn(x)Fn-1(x) + xFn-1(x)Fn-2(x)
    now = 1ll;
    for (int i = 0; i < 65536; i++, (now *= gn) %= mod) F22[i] = (F3[i] * F2[i] + now * ((F2[i] * F1[i]) % mod)) % mod;
    NTT(F22, 65536, -1);
    // F2n-2(x) = Fn-1(x)^2 + xFn-2(x)^2
    now = 1ll;
    for (int i = 0; i < 65536; i++, (now *= gn) %= mod) F21[i] = (F2[i] * F2[i] + now * ((F1[i] * F1[i]) % mod)) % mod;
    NTT(F21, 65536, -1);
    // (F1, F2, F3) <- (F21, F22, F23)
    for (int i = 0; i < 65536; i++) F1[i] = F21[i], F2[i] = F22[i], F3[i] = F23[i];
}
void Plus1() {  // n -> n + 1
    // (F1, F2, F3) <- (F2, F3, *)
    for (int i = 0; i < 65536; i++) F1[i] = F2[i], F2[i] = F3[i];
    NTT(F1, 65536, 1); NTT(F2, 65536, 1);
    LL gn = qpow(3, (mod - 1) >> 16), now = 1ll;
    // Fn(x) = Fn-1(x) + xFn-1(x) + xFn-2(x)
    for (int i = 0; i < 65536; i++, (now *= gn) %= mod) F3[i] = (F2[i] + now * F2[i] + now * F1[i]) % mod;
    NTT(F3, 65536, -1); NTT(F2, 65536, -1); NTT(F1, 65536, -1);
}
int main() {
	int n, k; scanf("%d%d", &n, &k);
    for (int i = 1; i < 65536; i++) Rev[i] = (Rev[i >> 1] >> 1) | ((i & 1) << (16 - 1));
    F3[0] = 1; // 初始化
    for (int i = 30; i >= 0; i--) {
        Pow2();
        for (int j = 32768; j < 65537; j++) F1[j] = F2[j] = F3[j] = 0ll;
        if ((n & (1 << i)) > 0) Plus1();
        for (int j = 32768; j < 65537; j++) F1[j] = F2[j] = F3[j] = 0ll;
    }
    for (int i = 1; i <= k; i++) printf("%lld%c", F3[i], " \n"[i == k]);
	return 0;
}

P4173 残缺的字符串

题目链接

给定长度为 mm 且带通配符的模式串 AA ,长度为 nn 且带通配符文本串 BB ,需要求出所有位置 pp ,满足 BB 串从第 pp 个字符开始的连续 mm 个字符,与 AA​ 串匹配。

题目解答

这是一道用 FFT/NTT 来解决字符串上的问题。

这是一道带通配符的单模式串匹配题。在解决这道题之前,我们从普通的单模式串匹配开始说起。

例题

给定长度为 mm 的模式串 AA ,长度为 nn 的文本串 BB ,需要求出所有位置 pp ,满足 BB 串从第 pp 个字符开始的连续 mm 个字符,与 AA​ 串完全相同。

例题解答

这道题固然可以用 KMP 去解,但是今天我们要用 FFT 来解决问题。不要因为 KMP 的复杂度优秀而忽略这一段,这段对解题很重要

为了方便说明,我们约定所有字符串的下标从 00 开始。AxA_x 表示 AA 串中下标为 xx 的字符的 ASCII 码,ByB_y 表示 BB 串中下标为 yy 的字符的 ASCII 码。

定义匹配函数 f(x,y)=(AxBy)2f(x, y) = (A_x - B_y)^2​​​ 。那么我们可以这样定义“匹配”:若 f(x,y)=0f(x, y) = 0​​ ,那么称 AA​​ 的第 xx​​ 个字符和 BB​​ 的第 yy​​ 个字符匹配,再定义完全匹配函数 P(x)=i=0m1f(i,x+i)P(x) = \sum_{i = 0}^{m - 1} f(i, x + i)​​ 。若 P(x)=0P(x) = 0​​ ,则称 BB​​ 以第 xx​​ 位开始的连续 mm​​ 位,与 AA​​ 完全匹配。

为什么这里是 (AxBy)2(A_x - B_y)^2 而非 (AxBy)(A_x - B_y) 呢?如果 P(x)=i=0m1(AxBy)P(x) = \sum_{i = 0}^{m - 1} (A_x - B_y) ,那么在此定义下,"osu" 和 "suo" 就是完全匹配的了。也就是说,后者的必要条件并非两个字符串完全相等。于是乎,我们就要让 P(x)P(x) 满足以下条件:

  • P(x)=0P(x) = 0 的充分必要条件为 BB 以第 xx 位开始的连续 mm 位相等。

发现在 (AxBy)(A_x - B_y) 两边加上一个绝对值可以。满足要求但是这样似乎就只能暴力计算。于是在 (AxBy)(A_x - B_y) 加上平方,就解决了这个问题。

回到函数 P(x)P(x)​ 上来。将 P(x)P(x)​ 拆开,有:P(x)=i=0m1f(i,x+i)=i=0m1(AiBx+i)2=i=0m1Ai2+i=0m1Bx+i22i=0m1AiBx+i\begin{aligned} P(x) &= \sum_{i = 0}^{m - 1} f(i, x + i) = \sum_{i = 0}^{m - 1} (A_i - B_{x + i})^2 = \sum_{i = 0}^{m - 1} A_i^2 + \sum_{i = 0}^{m - 1} B_{x + i}^2 - 2\sum_{i = 0}^{m - 1} A_iB_{x + i} \end{aligned}​ 。

前两项可以通过前缀和预处理,而第三项可以利用 trick 进行卷积。

回到原题

原题显然用 KMP 就无法解决了,我们还是考虑和上面类似的方法。那么我们回顾上面的普通串匹配过程,我们可以总结出思路大概是这样的:

定义匹配函数 \rightarrow 定义完全匹配函数 \rightarrow 快速计算每一位的完全匹配函数值

我们设通配符的数值为 00 ,定义匹配函数 f(x,y)=(AxBy)2AxByf(x, y) = (A_x - B_y)^2A_xB_y ,那么完全匹配函数 P(x)=i=0m1f(i,i+x)P(x) = \sum_{i = 0}^{m - 1} f(i, i + x)

P(x)P(x)​​ 拆开,有:P(x)=i=0m1f(i,i+x)=i=0m1(AiBi+x)2AiBi+x=i=0m1Ai3Bi+x2i=0m1Ai2Bi+x2+i=0m1AiBi+x3\begin{aligned} P(x) &= \sum_{i = 0}^{m - 1} f(i, i + x) = \sum_{i = 0}^{m - 1} (A_i - B_{i + x})^2A_iB_{i + x} = \sum_{i = 0}^{m - 1} A_i^3B_{i + x} - 2\sum_{i = 0}^{m - 1} A_i^2B_{i + x}^2 + \sum_{i = 0}^{m - 1} A_iB_{i + x}^3 \end{aligned}​​ 。

于是这三项都可以利用 trick 进行卷积。

总共进行 6 次 FFT 和 1 次 IFFT,时间复杂度 O(nlogn)O(n \log n) ,常数不小。所以如果 TLE 了,就开 O2 优化吧。代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1 << 21 | 5;
const LL mod = 998244353ll;
struct Complex {
	double x, y; // x + yi
	Complex() {x = y = 0.0;}
	Complex(double xx, double yy) {x = xx; y = yy;}
};
Complex operator + (Complex A, Complex B) {
	return Complex(A.x + B.x, A.y + B.y);
}
Complex operator - (Complex A, Complex B) {
	return Complex(A.x - B.x, A.y - B.y);
}
Complex operator * (Complex A, Complex B) {
	return Complex(A.x * B.x - A.y * B.y, A.x * B.y + A.y * B.x);
}
LL A[N], B[N];
Complex fA[N], fB[N], fAns[N];
int Rev[N];
char ch1[N], ch2[N];
void FFT(Complex *F, int n, int Sign) { // Sign = 1: FFT , Sign = -1: IFFT
	for (int i = 0; i < n; i++) {
		if (i < Rev[i]) swap(F[i], F[Rev[i]]);
	}
	for (int len = 2; len <= n; len <<= 1) {
		int mid = len >> 1;
		Complex wl = Complex(cos(M_PI / mid), Sign * sin(M_PI / mid));
		for (int i = 0; i < n; i += len) {
			Complex Ang = Complex(1, 0);
			for (int j = 0; j < mid; j++) {
				Complex tx = F[i + j], ty = F[i + j + mid];
				F[i + j] = tx + Ang * ty; F[i + j + mid] = tx - Ang * ty;
				Ang = Ang * wl;
			}
		}
	}
    if (Sign == -1) {
        for (int i = 0; i < n; i++) F[i] = Complex(F[i].x / n, 0);
    }
}
int main() {
    int n, m, p = 0;
    scanf("%d%d%s%s", &n, &m, ch1, ch2);
    for (int i = 0; i < n; i++) A[i] = ch1[n - 1 - i] == '*' ? 0 : (ch1[n - 1 - i] - 'a' + 1);
    for (int i = 0; i < m; i++) B[i] = ch2[i] == '*' ? 0 : (ch2[i] - 'a' + 1);
    while ((1 << p) <= n + m) p++;
    for (int i = 1; i < (1 << p); i++) {
        Rev[i] = (Rev[i >> 1] >> 1) | ((i & 1) << (p - 1));
    }
    // AB(A - B)^2 = A^3B - 2A^2B^2 + AB^3
    // 1. A^3B
    for (int i = 0; i < (1 << p); i++) fA[i] = fB[i] = Complex();
    for (int i = 0; i < n; i++) fA[i] = Complex(A[i] * A[i] * A[i], 0);
    for (int i = 0; i < m; i++) fB[i] = Complex(B[i], 0);
    FFT(fA, 1 << p, 1); FFT(fB, 1 << p, 1);
    for (int i = 0; i < (1 << p); i++) fAns[i] = fAns[i] + fA[i] * fB[i];
    // 2. A^2B^2
    for (int i = 0; i < (1 << p); i++) fA[i] = fB[i] = Complex();
    for (int i = 0; i < n; i++) fA[i] = Complex(A[i] * A[i], 0);
    for (int i = 0; i < m; i++) fB[i] = Complex(B[i] * B[i], 0);
    FFT(fA, 1 << p, 1); FFT(fB, 1 << p, 1);
    for (int i = 0; i < (1 << p); i++) fAns[i] = fAns[i] - fA[i] * fB[i] - fA[i] * fB[i];
    // 3. AB^3
    for (int i = 0; i < (1 << p); i++) fA[i] = fB[i] = Complex();
    for (int i = 0; i < n; i++) fA[i] = Complex(A[i], 0);
    for (int i = 0; i < m; i++) fB[i] = Complex(B[i] * B[i] * B[i], 0);
    FFT(fA, 1 << p, 1); FFT(fB, 1 << p, 1);
    for (int i = 0; i < (1 << p); i++) fAns[i] = fAns[i] + fA[i] * fB[i];
    // IFFT
    FFT(fAns, 1 << p, -1);
    int ans = 0;
    for (int i = n - 1; i < m; i++) if ((LL)(fAns[i].x + 0.5) < 1e-6) ans++;
    printf("%d\n", ans);
    for (int i = n - 1; i < m; i++) if ((LL)(fAns[i].x + 0.5) < 1e-6) printf("%d ", i - n + 2);
    return 0 & puts("");
}

请注意,理论上来说,使用 NTT 并不正确但是因为数据太水,所以把所有 FFT 换成 NTT 也可以通过此题

P4199 万径人踪灭

题目链接

在一个只包含字符 ab 的字符串中选取一个子序列,使得位置和字符都关于一条对称轴对称,且不能是连续的一段。问能选取的子序列的方案数 模上 109+710^9 + 7​ 的值。

题目提示

所求方案数 == 位置对称的回文子序列数 - 回文子串数。

回文子串数可以使用 manacher 解决。(当然也可以使用二分+哈希,不过 manacher 是 O(n)O(n),而二分+哈希是 O(nlogn)O(n \log n)

而位置对称的回文子序列数可以使用 FFT/NTT 解决(可以类比上一题)。

于是时间复杂度为 O(nlogn)O(n \log n),可以通过此题。

拓展

前置知识:多项式牛顿迭代

多项式牛顿迭代可以解决如下问题:给定 g(x)g(x) ,请在模 xnx^n 的意义下求出多项式 f(x)f(x) 满足 g(f(x))0(modxn)g(f(x)) \equiv 0 (\bmod x^n)

考虑倍增。

首先当 n=1n = 1 时,g(f(x))0(modxn)g(f(x)) \equiv 0 (\bmod x^n) 的解需要单独求出。

假设现在已经得到了模 xn2x^{\left\lceil\frac{n}{2}\right\rceil} 意义下的解 f0(x)f_0(x),要求模 xnx^n 意义下的解 f(x)f(x)

g(f(x))g(f(x))f0(x)f_0(x) 处进行泰勒展开,有:g(f(x))=i=0+g(i)(f0(x))i!(f(x)f0(x))i\begin{aligned} g(f(x)) = \sum_{i=0}^{+\infty}\frac{g^{(i)}(f_{0}(x))}{i!}(f(x)-f_{0}(x))^{i} \end{aligned} ,其中 g(i)(x)g^{(i)}(x) 表示 g(x)g(x)ii 阶导。

因为 f(x)f0(x)0(modxn2)f(x) - f_0(x) \equiv 0 (\bmod x^{\left\lceil\frac{n}{2}\right\rceil}) ,所以对于任意 i2i \ge 2 ,都有 (f(x)f0(x))i0(modxn)(f(x) - f_0(x))^i \equiv 0 (\bmod x^n)

于是 g(f(x))0(modxn)g(f0(x))g(f0(x))(f(x)f0(x))0(modxn)g(f(x)) \equiv 0(\bmod x^n) \Leftrightarrow g(f_0(x)) - g'(f_0(x))(f(x) - f_0(x)) \equiv 0 (\bmod x^n)

得到 f(x)f0(x)g(f0(x))g(f0(x))(modxn)\begin{aligned} f(x)\equiv f_{0}(x)-\frac{g(f_{0}(x))}{g'(f_{0}(x))}(\bmod x^n) \end{aligned}

多项式求逆

题目链接

给定一个多项式 f(x)f(x) ,请求出一个多项式 g(x)g(x), 满足 f(x)g(x)1(modxn)f(x) \cdot g(x) \equiv 1 (\operatorname{mod} x^n)。系数对 998244353998244353 取模。其中 degf=degg=n1\operatorname{deg} f = \operatorname{deg} g = n - 1

这里,若一个多项式 f(x)1(modxn)f(x) \equiv 1 (\operatorname{mod} x^n) ,则 f(x)=g(x)xn+1f(x) = g(x) \cdot x^n + 1 ,其中 g(x)g(x) 为关于 xx 的多项式(这一行的 f,gf, g 与题干中的无关)。

题目解答

带入牛顿迭代。

定义函数 h(g(x))=1g(x)f(x)\begin{aligned} h(g(x)) = \frac 1 {g(x)} - f(x) \end{aligned} ,问题转化为在模 xnx^n 的意义下求出多项式 g(x)g(x) 满足 h(g(x))0(modxn)h(g(x)) \equiv 0 (\bmod x^n)

首先当 n=1n = 1 时,h(g(x))0(modxn)h(g(x)) \equiv 0 (\bmod x^n) 的解可以直接求出。

假设现在已经得到了模 xn2x^{\left\lceil\frac{n}{2}\right\rceil} 意义下的解 g0(x)g_0(x),要求模 xnx^n 意义下的解 g(x)g(x)

由牛顿迭代知:

g(x)g0(x)h(g0(x))h(g0(x))(modxn)g0(x)1g0(x)f(x)1g02(x)(modxn)2g0(x)f(x)g02(x)(modxn)\begin{aligned} g(x) &\equiv g_0(x) - \frac{h(g_0(x))}{h'(g_0(x))} &(\bmod x^n) \\ &\equiv g_0(x) - \frac{\frac 1 {g_0(x)} - f(x)}{-\frac 1 {g_0^2(x)}} &(\bmod x^n) \\ &\equiv 2g_0(x) - f(x)g_0^2(x) &(\bmod x^n) \end{aligned}

时间复杂度 T(n)=T(n2)+O(nlogn)=O(nlogn)T\left(n\right)=T\left(\frac{n}{2}\right)+O\left(n\log{n}\right)=O\left(n\log{n}\right)

这里贴一下代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1 << 22 | 5;
const LL mod = 998244353ll;
LL A[N], B[N]; // f, g
LL fA[N], fB[N], fAns[N];
int rev[23][N];
inline LL qpow(LL n, LL p, LL mod = 998244353ll) {
    LL ret = 1ll;
    while (p) {
        if (p & 1) (ret *= n) %= mod;
        (n *= n) %= mod; p >>= 1;
    }
    return ret;
}
void NTT(LL *F, int *Rev, int n, int Sign) { // Sign = 1: NTT , Sign = -1: INTT
	for (int i = 0; i < n; i++) {
		if (i < Rev[i]) swap(F[i], F[Rev[i]]);
	}
	for (int len = 2; len <= n; len <<= 1) {
		int mid = len >> 1;
        LL wl = qpow(3, (mod - 1) / len);
		for (int i = 0; i < n; i += len) {
			LL Ang = 1;
			for (int j = 0; j < mid; j++) {
				LL tx = F[i + j], ty = (F[i + j + mid] * Ang) % mod;
				F[i + j] = (tx + ty) % mod; F[i + j + mid] = (tx - ty + mod) % mod;
				Ang = (Ang * wl) % mod;
			}
		}
	}
    if (Sign == -1) {
        for (int i = 1; i < (n >> 1); i++) swap(F[i], F[n - i]);
        LL Invn = qpow(n, mod - 2);
        for (int i = 0; i < n; i++) {
            (F[i] *= Invn) %= mod;
        }
    }
}
LL Tmp[N]; // Tmp 用来存 f
void Polyinv(LL *f, LL *g, int n) {
    if (n == 1) {g[0] = qpow(f[0], mod - 2); return;}
    Polyinv(f, g, (n + 1) >> 1);
    int p = 0; while ((1 << p) <= n + n) p++;
    for (int i = 0; i < (1 << p); i++) Tmp[i] = (i < n) ? f[i] : 0; // 超过 n 的都设为 0
    NTT(Tmp, rev[p], 1 << p, 1); NTT(g, rev[p], 1 << p, 1);
    for (int i = 0; i < (1 << p); i++) Tmp[i] = ((g[i] * ((2ll - Tmp[i] * g[i]) % mod)) % mod + mod) % mod;
    NTT(Tmp, rev[p], 1 << p, -1);
    for (int i = 0; i < (1 << p); i++) g[i] = (i < n) ? Tmp[i] : 0; // 超过 n 的都设为 0, 否则可能影响后面结果.
}
int main() {
	int n; scanf("%d", &n);
	for (int i = 0; i < n; i++) scanf("%lld", A + i);
    for (int k = 1; k < 22; k++) {
        for (int i = 1, End = (1 << k); i < End; i++) {
            rev[k][i] = (rev[k][i >> 1] >> 1) | ((i & 1) << (k - 1));
        }
    }
    Polyinv(A, B, n);
    for (int i = 0; i < n; i++) {
        printf("%lld%c", B[i], " \n"[i == n - 1]);
    }
	return 0;
}

多项式开根

题目链接

给定一个多项式 f(x)f(x) ,请求出一个多项式 g(x)g(x), 满足 (g(x))2f(x)(modxn)(g(x))^2 \equiv f(x) (\operatorname{mod} x^n)。系数对 998244353998244353 取模。其中 degf=degg=n1\operatorname{deg} f = \operatorname{deg} g = n - 1。该题保证多项式的 00 次项为 11 (否则需要用到二次剩余)。

题目解答

带入牛顿迭代。

定义函数 h(g(x))=g2(x)f(x)\begin{aligned} h(g(x)) = g^2(x) - f(x) \end{aligned} ,问题转化为在模 xnx^n 的意义下求出多项式 g(x)g(x) 满足 h(g(x))0(modxn)h(g(x)) \equiv 0 (\bmod x^n)

首先当 n=1n = 1 时,h(g(x))0(modxn)h(g(x)) \equiv 0 (\bmod x^n) 的解可以直接求出。

假设现在已经得到了模 xn2x^{\left\lceil\frac{n}{2}\right\rceil} 意义下的解 g0(x)g_0(x),要求模 xnx^n 意义下的解 g(x)g(x)

由牛顿迭代知:

g(x)g0(x)h(g0(x))h(g0(x))(modxn)g0(x)g02(x)f(x)2g0(x)(modxn)g02(x)+f(x)2g0(x)(modxn)\begin{aligned} g(x) &\equiv g_0(x) - \frac{h(g_0(x))}{h'(g_0(x))} &(\bmod x^n) \\ &\equiv g_0(x) - \frac{g_0^2(x) - f(x)}{2g_0(x)} &(\bmod x^n) \\ &\equiv \frac{g_0^2(x) + f(x)}{2g_0(x)} &(\bmod x^n) \end{aligned}

时间复杂度 T(n)=T(n2)+O(nlogn)=O(nlogn)T\left(n\right)=T\left(\frac{n}{2}\right)+O\left(n\log{n}\right)=O\left(n\log{n}\right)

这里贴一下代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1 << 22 | 5;
const LL mod = 998244353ll, inv2 = 499122177ll;
LL A[N], B[N]; // f, g
LL fA[N], fB[N], fAns[N];
int rev[23][N];
inline LL qpow(LL n, LL p, LL mod = 998244353ll) {
    LL ret = 1ll;
    while (p) {
        if (p & 1) (ret *= n) %= mod;
        (n *= n) %= mod; p >>= 1;
    }
    return ret;
}
void NTT(LL *F, int *Rev, int n, int Sign) { // Sign = 1: NTT , Sign = -1: INTT
	for (int i = 0; i < n; i++) {
		if (i < Rev[i]) swap(F[i], F[Rev[i]]);
	}
	for (int len = 2; len <= n; len <<= 1) {
		int mid = len >> 1;
        LL wl = qpow(3, (mod - 1) / len);
		for (int i = 0; i < n; i += len) {
			LL Ang = 1;
			for (int j = 0; j < mid; j++) {
				LL tx = F[i + j], ty = (F[i + j + mid] * Ang) % mod;
				F[i + j] = (tx + ty) % mod; F[i + j + mid] = (tx - ty + mod) % mod;
				Ang = (Ang * wl) % mod;
			}
		}
	}
    if (Sign == -1) {
        for (int i = 1; i < (n >> 1); i++) swap(F[i], F[n - i]);
        LL Invn = qpow(n, mod - 2);
        for (int i = 0; i < n; i++) {
            (F[i] *= Invn) %= mod;
        }
    }
}
LL Tmp[N]; // 用来存 f
void Polyinv(LL *f, LL *g, int n) {
    if (n == 1) {g[0] = qpow(f[0], mod - 2); return;}
    Polyinv(f, g, (n + 1) >> 1);
    int p = 0; while ((1 << p) <= n + n) p++;
    for (int i = 0; i < (1 << p); i++) Tmp[i] = (i < n) ? f[i] : 0;
    NTT(Tmp, rev[p], 1 << p, 1); NTT(g, rev[p], 1 << p, 1);
    for (int i = 0; i < (1 << p); i++) Tmp[i] = ((g[i] * ((2ll - Tmp[i] * g[i]) % mod)) % mod + mod) % mod;
    NTT(Tmp, rev[p], 1 << p, -1);
    for (int i = 0; i < (1 << p); i++) g[i] = (i < n) ? Tmp[i] : 0;
}
LL ginv[N]; // 用来存 h 的逆元
void PolySqrt(LL *f, LL *g, int n) {
    if (n == 1) {g[0] = 1; return;}
    PolySqrt(f, g, (n + 1) >> 1);
    int p = 0; while ((1 << p) <= n + n) p++;
    for (int i = 0; i < (1 << p); i++) ginv[i] = 0;
    Polyinv(g, ginv, n);
    for (int i = 0; i < (1 << p); i++) Tmp[i] = (i < n) ? f[i] : 0;
    NTT(Tmp, rev[p], 1 << p, 1); NTT(g, rev[p], 1 << p, 1); NTT(ginv, rev[p], 1 << p, 1);
    for (int i = 0; i < (1 << p); i++) Tmp[i] = (inv2 * ((g[i] + ginv[i] * Tmp[i]) % mod)) % mod;
    NTT(Tmp, rev[p], 1 << p, -1);
    for (int i = 0; i < (1 << p); i++) g[i] = (i < n) ? Tmp[i] : 0;
}
int main() {
	int n; scanf("%d", &n);
	for (int i = 0; i < n; i++) scanf("%lld", A + i);
    for (int k = 1; k < 22; k++) {
        for (int i = 1, End = (1 << k); i < End; i++) {
            rev[k][i] = (rev[k][i >> 1] >> 1) | ((i & 1) << (k - 1));
        }
    }
    PolySqrt(A, B, n);
    for (int i = 0; i < n; i++) {
        printf("%lld%c", B[i], " \n"[i == n - 1]);
    }
	return 0;
}

请注意,如果该题多项式的 00 次项不为 11 ,就需要使用二次剩余来计算 g(x)g(x)00 次项。

任意模数多项式乘法 (MTT)

题目链接

题目大意

给定两个多项式 f(x),g(x)f(x), g(x) ,请求出 f(x)g(x)f(x) * g(x) 。系数对 pp 取模,且不保证 pp 可以分解成 p=a2k+1p = a \cdot 2^k + 1 之形式。

三模数 NTT

任意模数 NTT,最大的数为 p2×max{n,m}1023p^2\times\max\{n,m\}\leq 10^{23} ,所以一般选三个模数即可,求出这三个模数下的答案,然后中国剩余定理即可。

假设这一位需要求得的答案是 xx,三个模数分别为 A,B,CA,B,C ,那么:

xx1(modA), xx2(modB), xx3(modC)x \equiv x_1 (\operatorname{mod} A),\ x \equiv x_2 (\operatorname{mod} B),\ x \equiv x_3 (\operatorname{mod} C)

先解出前两个:xx1(modA), xx2(modB)x \equiv x_1 (\operatorname{mod} A),\ x \equiv x_2 (\operatorname{mod} B) \Rightarrowx1+k1A=x2+k2Bx_1 + k_1 \cdot A = x_2 + k_2 \cdot Bx1+k1Ax2(modB)x_1 + k_1 \cdot A \equiv x_2 (\operatorname{mod} B)

所以 k1x2x1A(modB)k_1 \equiv \frac{x_2 - x_1}{A} (\operatorname{mod} B),这样就求出了 k1k_1。也就求出了 xx1+k1A(modAB)x \equiv x_1 + k_1 \cdot A (\operatorname{mod} AB),令 x4=x1+k1Ax_4=x_1+k_1 \cdot A

现在就变成了以下两个式子:xx4=x1+k1A(modAB), xx3(modC)x \equiv x_4 = x_1 + k_1 \cdot A (\operatorname{mod} AB),\ x \equiv x_3 (\operatorname{mod} C)

x4+k4AB=x3+k3Cx_4 + k_4 \cdot AB = x_3 + k_3 \cdot Cx4+k4ABx3(modC)x_4 + k_4 \cdot AB \equiv x_3 (\operatorname{mod} C)

所以 k4x3x4AB(modC)k_4 \equiv \frac{x_3 - x_4}{AB} (\operatorname{mod} C),这样就求出了 k4k_4。也就求出了 xx4+k4AB(modABC)x \equiv x_4 + k_4 \cdot AB (\operatorname{mod} ABC)。因为我们要求的 xx 小于 ABCABC,而 x4+k4AB<ABCx_4 + k_4 \cdot AB < ABC,所以 x=x4+k4ABx = x_4 + k_4 \cdot AB,就做完了。共 9 次 NTT,常数巨大,最慢的点 847ms。代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1 << 22 | 5;
const LL mod1 = 998244353ll, mod2 = 1004535809ll, mod3 = 469762049ll;
LL A[N], B[N], mod;
int rev[23][N];
inline LL qpow(LL n, LL p, LL mod = 998244353ll) {
    LL ret = 1ll;
    while (p) {
        if (p & 1) (ret *= n) %= mod;
        (n *= n) %= mod; p >>= 1;
    }
    return ret;
}
void NTT(LL *F, int *Rev, int n, int Sign, LL mod = 998244353ll) { // Sign = 1: NTT , Sign = -1: INTT
	for (int i = 0; i < n; i++) {
		if (i < Rev[i]) swap(F[i], F[Rev[i]]);
	}
	for (int len = 2; len <= n; len <<= 1) {
		int mid = len >> 1;
        LL wl = qpow(3, (mod - 1) / len, mod);
		for (int i = 0; i < n; i += len) {
			LL Ang = 1;
			for (int j = 0; j < mid; j++) {
				LL tx = F[i + j], ty = (F[i + j + mid] * Ang) % mod;
				F[i + j] = (tx + ty) % mod; F[i + j + mid] = (tx - ty + mod) % mod;
				Ang = (Ang * wl) % mod;
			}
		}
	}
    if (Sign == -1) {
        for (int i = 1; i < (n >> 1); i++) swap(F[i], F[n - i]);
        LL Invn = qpow(n, mod - 2, mod);
        for (int i = 0; i < n; i++) {
            (F[i] *= Invn) %= mod;
        }
    }
}
LL A1[N], A2[N], A3[N];
LL B1[N], B2[N], B3[N];
LL C1[N], C2[N], C3[N];
LL F(LL c1, LL c2, LL c3, LL mod) {
    // mod1, mod2, mod3, mod 分别代表 A, B, C, p
    // c1, c2, c3 分别代表 x1, x2, x3
    LL k1 = ((c2 - c1 + mod2) * qpow(mod1, mod2 - 2, mod2)) % mod2;
    LL x4 = c1 + k1 * mod1;
    LL k4 = ((((c3 - x4) % mod3 + mod3) % mod3) * qpow((mod1 * mod2) % mod3, mod3 - 2, mod3)) % mod3;
    return (x4 + k4 * ((mod1 * mod2) % mod)) % mod;
}
int main() {
	int n, m, p = 0;
	scanf("%d%d%lld", &n, &m, &mod);
	for (int i = 0; i <= n; i++) scanf("%lld", A + i), A1[i] = A2[i] = A3[i] = A[i];
	for (int i = 0; i <= m; i++) scanf("%lld", B + i), B1[i] = B2[i] = B3[i] = B[i];
	while ((1 << p) <= n + m) p++;
    for (int k = 1; k < 22; k++) {
        for (int i = 1, End = (1 << k); i < End; i++) {
            rev[k][i] = (rev[k][i >> 1] >> 1) | ((i & 1) << (k - 1));
        }
    }
	NTT(A1, rev[p], 1 << p, 1, mod1); NTT(A2, rev[p], 1 << p, 1, mod2); NTT(A3, rev[p], 1 << p, 1, mod3);
	NTT(B1, rev[p], 1 << p, 1, mod1); NTT(B2, rev[p], 1 << p, 1, mod2); NTT(B3, rev[p], 1 << p, 1, mod3);
    for (int i = 0, End = (1 << p); i < End; i++) {
        C1[i] = A1[i] * B1[i]; C2[i] = A2[i] * B2[i]; C3[i] = A3[i] * B3[i];
    }
	NTT(C1, rev[p], 1 << p, -1, mod1); NTT(C2, rev[p], 1 << p, -1, mod2); NTT(C3, rev[p], 1 << p, -1, mod3);
    for (int i = 0; i <= n + m; i++) {
        printf("%lld%c", F(C1[i], C2[i], C3[i], mod), " \n"[i == n + m]);
    }
	return 0;
}

拆系数 FFT

k=pk = \sqrt{p}(这里可以固定为 32768) ,多项式 f(x)=kA(x)+B(x)f(x)=k \cdot A(x)+B(x)g(x)=kC(x)+D(x)g(x)=k \cdot C(x)+D(x),然后 f(x)g(x)=k2A(x)C(x)+k(A(x)D(x)+B(x)C(x))+B(x)D(x)f(x) \cdot g(x) = k^2 \cdot A(x)C(x) + k \cdot (A(x)D(x) + B(x)C(x)) + B(x)D(x)。需进行 8 次 FFT(4 次将 A(x),B(x),C(x),D(x)A(x), B(x), C(x), D(x) 用点值表示法表示,4 次将 A(x)C(x),A(x)D(c),B(x)C(x),B(x)D(x)A(x)C(x), A(x)D(c), B(x)C(x), B(x)D(x) 从点值表示法转换为系数表示法)。常数很大,于是考虑优化。

为啥要拆系数?是为了将系数的位数变小。原本卷积后产生的最大的数约为 109×109×105=102310^9 \times 10^9 \times 10^5 = 10^{23},拆系数后,卷积后产生的最大的数约为 215×215×105=10142^{15} \times 2^{15} \times 10^5 = 10^{14}。可以使用常规数据类型( 如 long double )进行储存。

FFT “二合一”优化

假如要对多项式 F(x)F(x)G(x)G(x) 进行 DFT,则考虑构造如下多项式 P(x)P(x),满足 P(x)=F(x)+G(x)iP(x) = F(x) + G(x) \cdot \text{i}

可以发现,P(ωnk)P(\omega_n^k) 的第 xx 项为

(Fx+Gxi)ωnxk=(Fx+Gxi)(cos(2πxkn)+sin(2πxkn)i)=Fx(cos(2πxkn)+sin(2πxkn)i)+Gx(cos(2πxkn)isin(2πxkn))=(Fxcos(2πxkn)Gxsin(2πxkn))+(Fxsin(2πxkn)+Gxcos(2πxkn))i\begin{aligned} (F_x + G_x \cdot \text{i}) \cdot \omega_n^{xk} & = (F_x + G_x \cdot \text{i}) \cdot (\cos(\frac{2\pi x k}{n}) + \sin(\frac{2\pi x k}{n}) \cdot \text{i}) \\ & = F_x \cdot (\cos(\frac{2\pi x k}{n}) + \sin(\frac{2\pi x k}{n}) \cdot \text{i}) + G_x \cdot (\cos(\frac{2\pi x k}{n}) \cdot \text{i} - \sin(\frac{2\pi x k}{n})) \\ & = (F_x \cdot \cos(\frac{2\pi x k}{n}) - G_x \cdot \sin(\frac{2\pi x k}{n})) + (F_x \cdot \sin(\frac{2\pi x k}{n}) + G_x \cdot \cos(\frac{2\pi x k}{n})) \cdot \text{i} \end{aligned}

P(ωnk)P(\omega_n^{-k}) 的第 xx 项为

(Fx+Gxi)ωnxk=(Fx+Gxi)(cos(2π(xk)n)+sin(2π(xk)n)i)=Fx(cos(2πxkn)sin(2πxkn)i)+Gx(cos(2πxkn)i+sin(2πxkn))=(Fxcos(2πxkn)+Gxsin(2πxkn))+(Fxsin(2πxkn)+Gxcos(2πxkn))i\begin{aligned} (F_x + G_x \cdot \text{i}) \cdot \omega_n^{-xk} & = (F_x + G_x \cdot \text{i}) \cdot (\cos(\frac{2\pi \cdot (-x k)}{n}) + \sin(\frac{2\pi \cdot (-x k)}{n}) \cdot \text{i}) \\ & = F_x \cdot (\cos(\frac{2\pi x k}{n}) - \sin(\frac{2\pi x k}{n}) \cdot \text{i}) + G_x \cdot (\cos(\frac{2\pi x k}{n}) \cdot \text{i} + \sin(\frac{2\pi x k}{n})) \\ & = (F_x \cdot \cos(\frac{2\pi x k}{n}) + G_x \cdot \sin(\frac{2\pi x k}{n})) + (-F_x \cdot \sin(\frac{2\pi x k}{n}) + G_x \cdot \cos(\frac{2\pi x k}{n})) \cdot \text{i} \end{aligned}

F(ωnk)F(\omega_n^k) 的第 xx 项为 Fxcos(2πxkn)+Fxsin(2πxkn)iF_x \cdot \cos(\frac{2\pi x k}{n}) + F_x \cdot \sin(\frac{2\pi x k}{n}) \cdot \text{i}G(ωnk)G(\omega_n^k) 的第 xx 项为 Gxcos(2πxkn)+Gxsin(2πxkn)iG_x \cdot \cos(\frac{2\pi x k}{n}) + G_x \cdot \sin(\frac{2\pi x k}{n}) \cdot \text{i}

所以 F(ωnk)=12(Re(P(ωnk))+Re(P(ωnk)))+12(Im(P(ωnk))Im(P(ωnk)))i\begin{aligned} F(\omega_n^k) = \frac{1}{2} \cdot (\operatorname{Re}(P(\omega_n^k)) + \operatorname{Re}(P(\omega_n^{-k}))) + \frac{1}{2} \cdot (\operatorname{Im}(P(\omega_n^k)) - \operatorname{Im}(P(\omega_n^{-k}))) \cdot \text{i} \end{aligned}

以及有

G(ωnk)=(12(Re(P(ωnk))Re(P(ωnk)))+12(Im(P(ωnk))+Im(P(ωnk)))i)(i)=12(Im(P(ωnk))+Im(P(ωnk)))12(Re(P(ωnk))Re(P(ωnk)))i\begin{aligned} G(\omega_n^k) & = (\frac{1}{2} \cdot (\operatorname{Re}(P(\omega_n^k)) - \operatorname{Re}(P(\omega_n^{-k}))) + \frac{1}{2} \cdot (\operatorname{Im}(P(\omega_n^k)) + \operatorname{Im}(P(\omega_n^{-k}))) \cdot \text{i}) \cdot (- \text{i}) \\ & = \frac{1}{2} \cdot (\operatorname{Im}(P(\omega_n^k)) + \operatorname{Im}(P(\omega_n^{-k}))) - \frac{1}{2} \cdot (\operatorname{Re}(P(\omega_n^k)) - \operatorname{Re}(P(\omega_n^{-k}))) \cdot \text{i} \end{aligned}

因此,对 PP 进行一次 DFT 即可。

这样子,我们就只需要进行一次 FFT 就可以对两个多项式进行 DFT 了。

使用上面的方法,就可以使用 2 次 FFT 求出 A(x),B(x),C(x),D(x)A(x), B(x), C(x), D(x) 的点值表达式了。

同样,假如要对多项式 F(x)F(x)G(x)G(x) 进行 IDFT,则同样可以考虑构造如下多项式 P(x)P(x),满足 P(x)=F(x)+G(x)iP(x) = F(x) + G(x) \cdot \text{i}

PP 进行一次 IDFT 即可。这样,多项式 F(x)F(x) 就在实部,G(x)G(x) 就在虚部。

这样子,我们就只需要进行一次 IFFT 就可以对两个多项式进行 IDFT 了。

使用上面的方法,就可以使用 2 次 FFT 求出 A(x)C(x), A(x)D(x), B(x)C(x), B(x)D(x)A(x) \cdot C(x),\ A(x) \cdot D(x),\ B(x) \cdot C(x),\ B(x) \cdot D(x) 的点值表达式了。

总共只用了 4 次 FFT。而通常来讲,2 次 FFT 和 3 次 NTT 时间差不多(亲测),因此 4 次 FFT 比 9 次 NTT 快出 3 个 NTT(或者 2 个 FFT),比较快。

不过 FFT 对精度的要求较高,因此不用加优化的地方不用乱加,防止精度出问题。

参考代码:(因为精度问题而被卡成了 50 pts50\ pts

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1 << 20 | 1;
struct Complex {
	long double x, y; // x + yi
	Complex() {x = y = 0.0;}
	Complex(long double xx, long double yy) {x = xx; y = yy;}
};
Complex operator + (Complex A, Complex B) {
	return Complex(A.x + B.x, A.y + B.y);
}
Complex operator - (Complex A, Complex B) {
	return Complex(A.x - B.x, A.y - B.y);
}
Complex operator * (Complex A, Complex B) {
	return Complex(A.x * B.x - A.y * B.y, A.x * B.y + A.y * B.x);
}
LL t1[N], t2[N], mod, t3[N];
Complex A[N], B[N], C[N], D[N];
Complex AC[N], AD[N], BC[N], BD[N];
int Rev[N];
void FFT(Complex *F, int n, int Sign) { // Sign = 1: FFT , Sign = -1: IFFT
	for (int i = 0; i < n; i++) {
		if (i < Rev[i]) swap(F[i], F[Rev[i]]);
	}
	for (int len = 2; len <= n; len <<= 1) {
		int mid = len >> 1;
		Complex wl = Complex(cos(M_PI / mid), Sign * sin(M_PI / mid));
		for (int i = 0; i < n; i += len) {
			Complex Ang = Complex(1, 0);
			for (int j = 0; j < mid; j++) {
				Complex tx = F[i + j], ty = F[i + j + mid];
				F[i + j] = tx + Ang * ty; F[i + j + mid] = tx - Ang * ty;
				Ang = Ang * wl;
			}
		}
	}
    if (Sign == -1) {
        for (int i = 0; i < n; i++) F[i] = Complex(F[i].x / n, F[i].y / n);
    }
}
Complex P[N];
void PairFFT(Complex *F, Complex *G, int n) {
    for (int i = 0; i < n; i++) P[i] = Complex(F[i].x, G[i].x);
    FFT(P, n, 1); P[n] = P[0];
    for (int i = 0; i < n; i++) F[i] = Complex(0.5 * (P[i].x + P[n - i].x), 0.5 * (P[i].y - P[n - i].y));
    for (int i = 0; i < n; i++) G[i] = Complex(0.5 * (P[i].y + P[n - i].y), -0.5 * (P[i].x - P[n - i].x));
}
void PairIFFT(Complex *F, Complex *G, int n) {
    for (int i = 0; i < n; i++) P[i] = F[i] + G[i] * Complex(0, 1);
    FFT(P, n, -1);
    for (int i = 0; i < n; i++) F[i] = Complex(P[i].x, 0), G[i] = Complex(P[i].y, 0);
}
int main() {
	int n, m, p = 0;
	scanf("%d%d%lld", &n, &m, &mod);
	for (int i = 1; i <= n + 1; i++) scanf("%lld", t1 + i);
	for (int i = 1; i <= m + 1; i++) scanf("%lld", t2 + i);
	while ((1 << p) <= n + m) p++;
	for (int i = 1, End = (1 << p); i < End; i++) {
		Rev[i] = (Rev[i >> 1] >> 1) | ((i & 1) << (p - 1));
	}
    for (int i = 1; i <= n + 1; i++) A[i - 1] = Complex(t1[i] >> 15, 0), B[i - 1] = Complex(t1[i] & 32767, 0);
    for (int i = 1; i <= m + 1; i++) C[i - 1] = Complex(t2[i] >> 15, 0), D[i - 1] = Complex(t2[i] & 32767, 0);
    PairFFT(A, B, 1 << p); PairFFT(C, D, 1 << p);
	for (int i = 0, End = (1 << p); i < End; i++) {
        AC[i] = A[i] * C[i]; AD[i] = A[i] * D[i];
        BC[i] = B[i] * C[i]; BD[i] = B[i] * D[i];
	}
    PairIFFT(AC, AD, 1 << p); PairIFFT(BC, BD, 1 << p);
	for (int i = 0; i <= n + m; i++) {
		printf("%lld%c", (((((1ll << 30) * ((LL)(AC[i].x + 0.5) % mod)) % mod) + ((1ll << 15) * (((LL)(AD[i].x + 0.5) + (LL)(BC[i].x + 0.5)) % mod) % mod) + (LL)(BD[i].x + 0.5)) % mod + mod) % mod, " \n"[i == n + m]);
	}
	return 0;
}

【模板】分治 FFT

题目链接

题目描述

给定序列 g1,n1g_{1, \cdots n - 1} ,求序列 f0,n1f_{0, \cdots n - 1} ,满足 fi=j=1ifijgjf_i = \sum_{j = 1}^i f_{i - j} g_j,其中 1in11 \le i \le n - 1f0=1f_0 = 1 。答案对 998244353998244353 取模。

数据范围 1n1051 \le n \le 10^50gi<9982443530 \le g_i < 998244353 。时限 5s5s

第一个整数表示 nn ,后面 n1n - 1 个整数,表示 g1,g2,gn1g_1, g_2, \cdots g_{n - 1}

样例1输入:4\n3 1 2 ,输出:1 3 10 35

题目解答

考虑分治。发现题目的要求类似于卷积,于是考虑使用 FFT/NTT 。

但是后面的数字基于前面的数字,无法快速计算,时间复杂度退化至 O(n2)O(n^2)

于是我们考虑将类似的转移同时进行,来节省复杂度。

考虑利用分治。(也就是 cdq 分治)

(以下默认 n=2m,mNn = 2^m, m \in \mathbb{N} 。对于 nn22 的幂次的情况,在序列后面补 00 即可)

在算一段区间 [l,r)[l, r) 时,将其拆成两个长度一样的区间 [l,mid)[l, mid)[mid,r)[mid, r) ,其中 mid=l+r2mid = \frac{l + r}2 。先求出 [l,mid)[l, mid) 之间的答案,之后去计算 [l,mid)[l, mid) 内的 ff[mid,r)[mid, r) 内的 ff 的贡献。假设其对 fxf_x 的贡献为 wxw_x ,其中 x[mid,r)x \in [mid, r) ,则有:wx=i[l,mid)gxifiw_x = \sum_{i \in [l, mid)} g_{x - i} \cdot f_i 。这部分可以利用卷积来快速计算。计算完以后,答案直接加到答案数组就可以了。

我们可以用样例来方便理解:给定 $g = $ [0, 3, 1, 2] ,求 ff 。(这里 g0=0g_0 = 0

一开始,f0=1f_0 = 1 。$f = $ [1, 0, 0, 0]

将其分为两段:[1, 0|0, 0]

先算左边的:[1, 0] ,将其分开:[1|0] 。因为左半边的长度为 11,于是不往下递归。计算左区间对右区间的贡献。就是将 [1, 0]gg 的前 22[0, 3] 做卷积,得到 [*, 3] ,用 * 表示这个数我们不在乎。将卷积的后半段 [3] 加到这个区间的右半边即可。

操作后得到:[1, 3] ,因为右半边长度为 11,于是不往下递归,这区间就做完了,回到上一步。

现在 $f = $ [1, 3|0, 0] ,和上面一样,计算左区间对右区间的贡献。就是将 [1, 3, 0, 0]gg 的前 22[0, 3, 1, 2] 做卷积,得到 [*, *, 10, 5] ,将卷积的后半段 [10, 5] 加到这个区间的右半边即可。

操作后得到:[1, 3, 10, 5] ,现在开始计算这个区间的右半段 [10, 5]

将其分开:[10|5] 。和之前一样,因为左半边的长度为 11,于是不往下递归。计算左区间对右区间的贡献。因为是左区间右区间的贡献,所以在算卷积时只能要前半段的 ff ,后半段均设为 00。也就是说,贡献就是将 [10, 0]gg 的前 22[0, 3] 做卷积,得到 [*, 30] ,将卷积的后半段 [30] 加到这个区间的右半边即可。

操作后得到:[10|35] 因为右半边长度为 11,于是不往下递归,这区间就做完了,回到上一步。

因此得到最终答案:[1, 3, 10, 35]

这里贴一下代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1 << 22 | 5;
const LL mod = 998244353ll;
LL A[N], B[N]; // g, f
LL fA[N], fB[N], fAns[N];
int rev[23][N];
inline LL qpow(LL n, LL p, LL mod = 998244353ll) {
    LL ret = 1ll;
    while (p) {
        if (p & 1) (ret *= n) %= mod;
        (n *= n) %= mod; p >>= 1;
    }
    return ret;
}
void NTT(LL *F, int *Rev, int n, int Sign) { // Sign = 1: NTT , Sign = -1: INTT
	for (int i = 0; i < n; i++) {
		if (i < Rev[i]) swap(F[i], F[Rev[i]]);
	}
	for (int len = 2; len <= n; len <<= 1) {
		int mid = len >> 1;
        LL wl = qpow(3, (mod - 1) / len);
		for (int i = 0; i < n; i += len) {
			LL Ang = 1;
			for (int j = 0; j < mid; j++) {
				LL tx = F[i + j], ty = (F[i + j + mid] * Ang) % mod;
				F[i + j] = (tx + ty) % mod; F[i + j + mid] = (tx - ty + mod) % mod;
				Ang = (Ang * wl) % mod;
			}
		}
	}
    if (Sign == -1) {
        for (int i = 1; i < (n >> 1); i++) swap(F[i], F[n - i]);
        LL Invn = qpow(n, mod - 2);
        for (int i = 0; i < n; i++) {
            (F[i] *= Invn) %= mod;
        }
    }
}
void solve(int l, int r, int Log2) { // [l, r)
    if (Log2 == 0) return;
    int mid = l + ((r - l) >> 1);
    solve(l, mid, Log2 - 1);
    for (int i = l; i < r; i++) fA[i - l] = A[i - l];
    for (int i = l; i < mid; i++) fB[i - l] = B[i];
    for (int i = mid; i < r; i++) fB[i - l] = 0;
    for (int i = (1 << (Log2)); i < (1 << (Log2 + 1)); i++) fA[i] = fB[i] = 0;
	NTT(fA, rev[Log2 + 1], 1 << (Log2 + 1), 1); NTT(fB, rev[Log2 + 1], 1 << (Log2 + 1), 1);
    for (int i = 0, End = (1 << (Log2 + 1)); i < End; i++) {
        fAns[i] = (fA[i] * fB[i]) % mod;
    }
    NTT(fAns, rev[Log2 + 1], 1 << (Log2 + 1), -1);
    for (int i = mid; i < r; i++) (B[i] += fAns[i - l]) %= mod;
    solve(mid, r, Log2 - 1);
}
int main() {
	int n, m, p = 0;
	scanf("%d", &n);
	for (int i = 1; i < n; i++) scanf("%lld", A + i); B[0] = 1;
	while ((1 << p) <= n + n) p++;
    for (int k = 1; k < 22; k++) {
        for (int i = 1, End = (1 << k); i < End; i++) {
            rev[k][i] = (rev[k][i >> 1] >> 1) | ((i & 1) << (k - 1));
        }
    }
    solve(0, 1 << p, p);
    for (int i = 0; i < n; i++) {
        printf("%lld%c", B[i], " \n"[i == n - 1]);
    }
	return 0;
}

那么看到这里应该会有疑问:题目背景中说的 “也可用多项式求逆解决。” 是什么意思。

就如字面意思,可以用多项式求逆解决。

F(x)=i=0+fixi,G(x)=i=0+gixi\begin{aligned} F(x) = \sum_{i = 0}^{+\infty} f_ix^i, G(x) = \sum_{i = 0}^{+\infty} g_ix^i \end{aligned}​ 以及 g0=0g_0 = 0 ,那么有

那么有 F(x)G(x)=i=0+(j=0ifjgij)xi=j=00f0g0x0+i=1+fixi=F(x)f0x0\begin{aligned} F(x) G(x) = \sum_{i = 0}^{+\infty} (\sum_{j = 0}^i f_jg_{i - j}) x^i = \sum_{j = 0}^0 f_0g_0x^0 + \sum_{i = 1}^{+\infty} f_ix^i = F(x) - f_0x^0 \end{aligned}​​​ 。

由于 f0=1f_0 = 1​ ,那么有 F(x)G(x)F(x)f0 (modxn)F(x)G(x) \equiv F(x) - f_0\ (\operatorname{mod} x^n)​ 。

F(x)(1G(x))f0(modxn)F(x)f01G(x)(modxn)F(x)(1 - G(x)) \equiv f_0 (\bmod x^n) \Rightarrow \begin{aligned} F(x) \equiv \frac{f_0}{1 - G(x)} (\bmod x^n) \end{aligned}

于是 F(x)(1G(x))1(modxn)F(x) \equiv (1 - G(x))^{-1} (\bmod x^n)​ 。那么就是一个多项式求逆的模板了。代码(和多项式求逆几乎一样):

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1 << 22 | 5;
const LL mod = 998244353ll;
LL A[N], B[N]; // f, g
LL fA[N], fB[N], fAns[N];
int rev[23][N];
inline LL qpow(LL n, LL p, LL mod = 998244353ll) {
    LL ret = 1ll;
    while (p) {
        if (p & 1) (ret *= n) %= mod;
        (n *= n) %= mod; p >>= 1;
    }
    return ret;
}
void NTT(LL *F, int *Rev, int n, int Sign) { // Sign = 1: NTT , Sign = -1: INTT
	for (int i = 0; i < n; i++) {
		if (i < Rev[i]) swap(F[i], F[Rev[i]]);
	}
	for (int len = 2; len <= n; len <<= 1) {
		int mid = len >> 1;
        LL wl = qpow(3, (mod - 1) / len);
		for (int i = 0; i < n; i += len) {
			LL Ang = 1;
			for (int j = 0; j < mid; j++) {
				LL tx = F[i + j], ty = (F[i + j + mid] * Ang) % mod;
				F[i + j] = (tx + ty) % mod; F[i + j + mid] = (tx - ty + mod) % mod;
				Ang = (Ang * wl) % mod;
			}
		}
	}
    if (Sign == -1) {
        for (int i = 1; i < (n >> 1); i++) swap(F[i], F[n - i]);
        LL Invn = qpow(n, mod - 2);
        for (int i = 0; i < n; i++) {
            (F[i] *= Invn) %= mod;
        }
    }
}
LL Tmp[N];
void Polyinv(LL *f, LL *g, int n) {
    if (n == 1) {g[0] = qpow(f[0], mod - 2); return;}
    Polyinv(f, g, (n + 1) >> 1);
    int p = 0; while ((1 << p) <= n + n) p++;
    for (int i = 0; i < (1 << p); i++) Tmp[i] = (i < n) ? f[i] : 0;
    NTT(Tmp, rev[p], 1 << p, 1); NTT(g, rev[p], 1 << p, 1);
    for (int i = 0; i < (1 << p); i++) Tmp[i] = ((g[i] * ((2ll - Tmp[i] * g[i]) % mod)) % mod + mod) % mod;
    NTT(Tmp, rev[p], 1 << p, -1);
    for (int i = 0; i < (1 << p); i++) g[i] = (i < n) ? Tmp[i] : 0;
}
int main() {
	int n, m, p = 0;
	scanf("%d", &n);
	for (int i = 1; i < n; i++) scanf("%lld", A + i), A[i] = -A[i];
    A[0] = 1;
    for (int k = 1; k < 22; k++) {
        for (int i = 1, End = (1 << k); i < End; i++) {
            rev[k][i] = (rev[k][i >> 1] >> 1) | ((i & 1) << (k - 1));
        }
    }
    Polyinv(A, B, n);
    for (int i = 0; i < n; i++) {
        printf("%lld%c", B[i], " \n"[i == n - 1]);
    }
	return 0;
}

多项式对数函数(多项式 ln)

题目链接

给定一个多项式 f(x)f(x)​ ,请求出一个多项式 g(x)g(x)​, 满足 g(x)lnf(x)(modxn)g(x) \equiv \ln f(x) (\operatorname{mod} x^n)​。系数对 998244353998244353​ 取模。其中 degf=degg=n1\operatorname{deg} f = \operatorname{deg} g = n - 1​。

题目解答

g(x)=lnf(x)g(x) = \ln f(x) 两边同时求导,有:g(x)=f(x)f(x)\begin{aligned} g'(x) = \frac{f'(x)}{f(x)} \end{aligned}​。

于是可以利用多项式求逆和多项式求导算出 gg 的导数,之后对 gg' 求不定积分即可。

补充:求导公式(xa)=axa1(x^{a})'=ax^{a-1}​​​,不定积分公式(求导的逆运算):xadx=1a+1xa+1\int x^a\text{d}x=\frac{1}{a+1}x^{a+1}​​。

这里贴一下代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1 << 22 | 5;
const LL mod = 998244353ll, inv2 = 499122177ll;
LL A[N], B[N]; // f, g
LL fA[N], fB[N], fAns[N];
int rev[23][N];
inline LL qpow(LL n, LL p, LL mod = 998244353ll) {
    LL ret = 1ll;
    while (p) {
        if (p & 1) (ret *= n) %= mod;
        (n *= n) %= mod; p >>= 1;
    }
    return ret;
}
void NTT(LL *F, int *Rev, int n, int Sign) { // Sign = 1: NTT , Sign = -1: INTT
	for (int i = 0; i < n; i++) {
		if (i < Rev[i]) swap(F[i], F[Rev[i]]);
	}
	for (int len = 2; len <= n; len <<= 1) {
		int mid = len >> 1;
        LL wl = qpow(3, (mod - 1) / len);
		for (int i = 0; i < n; i += len) {
			LL Ang = 1;
			for (int j = 0; j < mid; j++) {
				LL tx = F[i + j], ty = (F[i + j + mid] * Ang) % mod;
				F[i + j] = (tx + ty) % mod; F[i + j + mid] = (tx - ty + mod) % mod;
				Ang = (Ang * wl) % mod;
			}
		}
	}
    if (Sign == -1) {
        for (int i = 1; i < (n >> 1); i++) swap(F[i], F[n - i]);
        LL Invn = qpow(n, mod - 2);
        for (int i = 0; i < n; i++) {
            (F[i] *= Invn) %= mod;
        }
    }
}
LL Tmp[N];
void Polyinv(LL *f, LL *g, int n) {
    if (n == 1) {g[0] = qpow(f[0], mod - 2); return;}
    Polyinv(f, g, (n + 1) >> 1);
    int p = 0; while ((1 << p) <= n + n) p++;
    for (int i = 0; i < (1 << p); i++) Tmp[i] = (i < n) ? f[i] : 0;
    NTT(Tmp, rev[p], 1 << p, 1); NTT(g, rev[p], 1 << p, 1);
    for (int i = 0; i < (1 << p); i++) Tmp[i] = ((g[i] * ((2ll - Tmp[i] * g[i]) % mod)) % mod + mod) % mod;
    NTT(Tmp, rev[p], 1 << p, -1);
    for (int i = 0; i < (1 << p); i++) g[i] = (i < n) ? Tmp[i] : 0;
}
LL finv[N];
void PolyLn(LL *f, LL *g, int n) {
    int p = 0; while ((1 << p) <= n + n) p++;
    for (int i = 0; i < (1 << p); i++) finv[i] = 0;
    Polyinv(f, finv, n);
    for (int i = 0; i < n - 1; i++) g[i] = (f[i + 1] * (i + 1)) % mod;
    for (int i = n - 1; i < (1 << p); i++) g[i] = 0;
    for (int i = 0; i < (1 << p); i++) Tmp[i] = g[i];
    NTT(finv, rev[p], 1 << p, 1); NTT(Tmp, rev[p], 1 << p, 1);
    for (int i = 0; i < (1 << p); i++) Tmp[i] = (finv[i] * Tmp[i]) % mod;
    NTT(Tmp, rev[p], 1 << p, -1);
    g[0] = 0;
    for (int i = 1; i < (1 << p); i++) g[i] = (i < n) ? ((Tmp[i - 1] * qpow(i, mod - 2)) % mod) : 0;
}
int main() {
	int n, m, p = 0;
	scanf("%d", &n);
	for (int i = 0; i < n; i++) scanf("%lld", A + i);
    for (int k = 1; k < 22; k++) {
        for (int i = 1, End = (1 << k); i < End; i++) {
            rev[k][i] = (rev[k][i >> 1] >> 1) | ((i & 1) << (k - 1));
        }
    }
    PolyLn(A, B, n);
    for (int i = 0; i < n; i++) {
        printf("%lld%c", B[i], " \n"[i == n - 1]);
    }
	return 0;
}

多项式指数函数(多项式 exp)

题目链接

给定一个多项式 f(x)f(x)​ ,请求出一个多项式 g(x)g(x)​, 满足 g(x)ef(x)(modxn)g(x) \equiv e^{f(x)} (\operatorname{mod} x^n)​。系数对 998244353998244353​ 取模。其中 degf=degg=n1\operatorname{deg} f = \operatorname{deg} g = n - 1​。

题目解答

带入牛顿迭代。

定义函数 h(g(x))=lng(x)f(x)\begin{aligned} h(g(x)) = \ln g(x) - f(x) \end{aligned} ,问题转化为在模 xnx^n 的意义下求出多项式 g(x)g(x) 满足 h(g(x))0(modxn)h(g(x)) \equiv 0 (\bmod x^n)

首先当 n=1n = 1 时,h(g(x))0(modxn)h(g(x)) \equiv 0 (\bmod x^n) 的解可以直接求出。

假设现在已经得到了模 xn2x^{\left\lceil\frac{n}{2}\right\rceil} 意义下的解 g0(x)g_0(x),要求模 xnx^n 意义下的解 g(x)g(x)

由牛顿迭代知:

g(x)g0(x)h(g0(x))h(g0(x))(modxn)g0(x)lng0(x)f(x)1g0(x)(modxn)g0(x)(1lng0(x)+f(x))(modxn)\begin{aligned} g(x) &\equiv g_0(x) - \frac{h(g_0(x))}{h'(g_0(x))} &(\bmod x^n) \\ &\equiv g_0(x) - \frac{\ln g_0(x) - f(x)}{\frac 1 {g_0(x)}} &(\bmod x^n) \\ &\equiv g_0(x)(1 - \ln g_0(x) + f(x)) &(\bmod x^n) \end{aligned}

时间复杂度 T(n)=T(n2)+O(nlogn)=O(nlogn)T\left(n\right)=T\left(\frac{n}{2}\right)+O\left(n\log{n}\right)=O\left(n\log{n}\right)

这里贴一下代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1 << 22 | 5;
const LL mod = 998244353ll, inv2 = 499122177ll;
LL A[N], B[N]; // f, g
LL fA[N], fB[N], fAns[N];
int rev[23][N];
inline LL qpow(LL n, LL p, LL mod = 998244353ll) {
    LL ret = 1ll;
    while (p) {
        if (p & 1) (ret *= n) %= mod;
        (n *= n) %= mod; p >>= 1;
    }
    return ret;
}
void NTT(LL *F, int *Rev, int n, int Sign) { // Sign = 1: NTT , Sign = -1: INTT
	for (int i = 0; i < n; i++) {
		if (i < Rev[i]) swap(F[i], F[Rev[i]]);
	}
	for (int len = 2; len <= n; len <<= 1) {
		int mid = len >> 1;
        LL wl = qpow(3, (mod - 1) / len);
		for (int i = 0; i < n; i += len) {
			LL Ang = 1;
			for (int j = 0; j < mid; j++) {
				LL tx = F[i + j], ty = (F[i + j + mid] * Ang) % mod;
				F[i + j] = (tx + ty) % mod; F[i + j + mid] = (tx - ty + mod) % mod;
				Ang = (Ang * wl) % mod;
			}
		}
	}
    if (Sign == -1) {
        for (int i = 1; i < (n >> 1); i++) swap(F[i], F[n - i]);
        LL Invn = qpow(n, mod - 2);
        for (int i = 0; i < n; i++) {
            (F[i] *= Invn) %= mod;
        }
    }
}
LL Tmp[N];
void Polyinv(LL *f, LL *g, int n) {
    if (n == 1) {g[0] = qpow(f[0], mod - 2); return;}
    Polyinv(f, g, (n + 1) >> 1);
    int p = 0; while ((1 << p) <= n + n) p++;
    for (int i = 0; i < (1 << p); i++) Tmp[i] = (i < n) ? f[i] : 0;
    NTT(Tmp, rev[p], 1 << p, 1); NTT(g, rev[p], 1 << p, 1);
    for (int i = 0; i < (1 << p); i++) Tmp[i] = ((g[i] * ((2ll - Tmp[i] * g[i]) % mod)) % mod + mod) % mod;
    NTT(Tmp, rev[p], 1 << p, -1);
    for (int i = 0; i < (1 << p); i++) g[i] = (i < n) ? Tmp[i] : 0;
}
LL ginv[N];
void PolySqrt(LL *f, LL *g, int n) {
    if (n == 1) {g[0] = 1; return;}
    PolySqrt(f, g, (n + 1) >> 1);
    int p = 0; while ((1 << p) <= n + n) p++;
    for (int i = 0; i < (1 << p); i++) ginv[i] = 0;
    Polyinv(g, ginv, n);
    for (int i = 0; i < (1 << p); i++) Tmp[i] = (i < n) ? f[i] : 0;
    NTT(Tmp, rev[p], 1 << p, 1); NTT(g, rev[p], 1 << p, 1); NTT(ginv, rev[p], 1 << p, 1);
    for (int i = 0; i < (1 << p); i++) Tmp[i] = (inv2 * ((g[i] + ginv[i] * Tmp[i]) % mod)) % mod;
    NTT(Tmp, rev[p], 1 << p, -1);
    for (int i = 0; i < (1 << p); i++) g[i] = (i < n) ? Tmp[i] : 0;
}
LL finv[N];
void PolyLn(LL *f, LL *g, int n) {
    int p = 0; while ((1 << p) <= n + n) p++;
    for (int i = 0; i < (1 << p); i++) finv[i] = 0;
    Polyinv(f, finv, n);
    for (int i = 0; i < n - 1; i++) g[i] = (f[i + 1] * (i + 1)) % mod;
    for (int i = n - 1; i < (1 << p); i++) g[i] = 0;
    for (int i = 0; i < (1 << p); i++) Tmp[i] = g[i];
    NTT(finv, rev[p], 1 << p, 1); NTT(Tmp, rev[p], 1 << p, 1);
    for (int i = 0; i < (1 << p); i++) Tmp[i] = (finv[i] * Tmp[i]) % mod;
    NTT(Tmp, rev[p], 1 << p, -1);
    g[0] = 0; for (int i = 1; i < (1 << p); i++) g[i] = (i < n) ? ((Tmp[i - 1] * qpow(i, mod - 2)) % mod) : 0;
}
LL gln[N];
void PolyExp(LL *f, LL *g, int n) {
	if (n == 1) {g[0] = 1; return;}
	PolyExp(f, g, (n + 1) >> 1);
	int p = 0; while ((1 << p) <= n + n) p++;
	for (int i = 0; i < (1 << p); i++) gln[i] = 0;
	PolyLn(g, gln, n);
	for (int i = 0; i < n; i++) gln[i] = ((i == 0) - gln[i] + f[i] + mod) % mod;
	for (int i = n; i < (1 << p); i++) gln[i] = 0;
	for (int i = 0; i < n; i++) Tmp[i] = g[i];
	for (int i = n; i < (1 << p); i++) Tmp[i] = 0;
	NTT(gln, rev[p], 1 << p, 1); NTT(Tmp, rev[p], 1 << p, 1);
	for (int i = 0; i < (1 << p); i++) Tmp[i] = (gln[i] * Tmp[i]) % mod;
	NTT(Tmp, rev[p], 1 << p, -1);
	for (int i = 0; i < (1 << p); i++) g[i] = (i < n) ? Tmp[i] : 0;
}
int main() {
	int n, m, p = 0;
	scanf("%d", &n);
	for (int i = 0; i < n; i++) scanf("%lld", A + i);
    for (int k = 1; k < 22; k++) {
        for (int i = 1, End = (1 << k); i < End; i++) {
            rev[k][i] = (rev[k][i >> 1] >> 1) | ((i & 1) << (k - 1));
        }
    }
    PolyExp(A, B, n);
    for (int i = 0; i < n; i++) {
        printf("%lld%c", B[i], " \n"[i == n - 1]);
    }
	return 0;
}

多项式快速幂

给定一个多项式 f(x)f(x)​​ ,请求出一个多项式 g(x)g(x)​​, 满足 g(x)(f(x))k(modxn)g(x) \equiv (f(x))^k (\operatorname{mod} x^n)​​。系数对 998244353998244353​​ 取模。其中 degf=degg=n1\operatorname{deg} f = \operatorname{deg} g = n - 1​​。

普通版

题目链接

有特殊性质:保证 [x0]f(x)=1[x^0]f(x) = 1​。

题目解答

因为 [x0]f(x)=1[x^0]f(x) = 1,所以可以直接求 lnf(x)\ln f(x) ,于是 g(x)(f(x))k(modxn)g(x)eklnf(x)(modxn)\begin{aligned} g(x) \equiv (f(x))^k (\operatorname{mod} x^n) \Leftrightarrow g(x) \equiv e^{k \ln f(x)} \end{aligned} (\operatorname{mod} x^n)

直接对 f(x)f(x)ln\ln,再乘 kk ,之后求 exp\exp​,就是答案了。

代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1 << 22 | 5;
const LL mod = 998244353ll, inv2 = 499122177ll;
LL A[N], B[N]; // f, g
LL fA[N], fB[N], fAns[N];
int rev[23][N];
inline LL qpow(LL n, LL p, LL mod = 998244353ll) {
    LL ret = 1ll;
    while (p) {
        if (p & 1) (ret *= n) %= mod;
        (n *= n) %= mod; p >>= 1;
    }
    return ret;
}
void NTT(LL *F, int *Rev, int n, int Sign) { // Sign = 1: NTT , Sign = -1: INTT
	for (int i = 0; i < n; i++) {
		if (i < Rev[i]) swap(F[i], F[Rev[i]]);
	}
	for (int len = 2; len <= n; len <<= 1) {
		int mid = len >> 1;
        LL wl = qpow(3, (mod - 1) / len);
		for (int i = 0; i < n; i += len) {
			LL Ang = 1;
			for (int j = 0; j < mid; j++) {
				LL tx = F[i + j], ty = (F[i + j + mid] * Ang) % mod;
				F[i + j] = (tx + ty) % mod; F[i + j + mid] = (tx - ty + mod) % mod;
				Ang = (Ang * wl) % mod;
			}
		}
	}
    if (Sign == -1) {
        for (int i = 1; i < (n >> 1); i++) swap(F[i], F[n - i]);
        LL Invn = qpow(n, mod - 2);
        for (int i = 0; i < n; i++) {
            (F[i] *= Invn) %= mod;
        }
    }
}
LL Tmp[N];
void Polyinv(LL *f, LL *g, int n) {
    if (n == 1) {g[0] = qpow(f[0], mod - 2); return;}
    Polyinv(f, g, (n + 1) >> 1);
    int p = 0; while ((1 << p) <= n + n) p++;
    for (int i = 0; i < (1 << p); i++) Tmp[i] = (i < n) ? f[i] : 0;
    NTT(Tmp, rev[p], 1 << p, 1); NTT(g, rev[p], 1 << p, 1);
    for (int i = 0; i < (1 << p); i++) Tmp[i] = ((g[i] * ((2ll - Tmp[i] * g[i]) % mod)) % mod + mod) % mod;
    NTT(Tmp, rev[p], 1 << p, -1);
    for (int i = 0; i < (1 << p); i++) g[i] = (i < n) ? Tmp[i] : 0;
}
LL ginv[N];
void PolySqrt(LL *f, LL *g, int n) {
    if (n == 1) {g[0] = 1; return;}
    PolySqrt(f, g, (n + 1) >> 1);
    int p = 0; while ((1 << p) <= n + n) p++;
    for (int i = 0; i < (1 << p); i++) ginv[i] = 0;
    Polyinv(g, ginv, n);
    for (int i = 0; i < (1 << p); i++) Tmp[i] = (i < n) ? f[i] : 0;
    NTT(Tmp, rev[p], 1 << p, 1); NTT(g, rev[p], 1 << p, 1); NTT(ginv, rev[p], 1 << p, 1);
    for (int i = 0; i < (1 << p); i++) Tmp[i] = (inv2 * ((g[i] + ginv[i] * Tmp[i]) % mod)) % mod;
    NTT(Tmp, rev[p], 1 << p, -1);
    for (int i = 0; i < (1 << p); i++) g[i] = (i < n) ? Tmp[i] : 0;
}
LL finv[N];
void PolyLn(LL *f, LL *g, int n) {
    int p = 0; while ((1 << p) <= n + n) p++;
    for (int i = 0; i < (1 << p); i++) finv[i] = 0;
    Polyinv(f, finv, n);
    for (int i = 0; i < n - 1; i++) g[i] = (f[i + 1] * (i + 1)) % mod;
    for (int i = n - 1; i < (1 << p); i++) g[i] = 0;
    for (int i = 0; i < (1 << p); i++) Tmp[i] = g[i];
    NTT(finv, rev[p], 1 << p, 1); NTT(Tmp, rev[p], 1 << p, 1);
    for (int i = 0; i < (1 << p); i++) Tmp[i] = (finv[i] * Tmp[i]) % mod;
    NTT(Tmp, rev[p], 1 << p, -1);
    g[0] = 0; for (int i = 1; i < (1 << p); i++) g[i] = (i < n) ? ((Tmp[i - 1] * qpow(i, mod - 2)) % mod) : 0;
}
LL gln[N];
void PolyExp(LL *f, LL *g, int n) {
	if (n == 1) {g[0] = 1; return;}
	PolyExp(f, g, (n + 1) >> 1);
	int p = 0; while ((1 << p) <= n + n) p++;
	for (int i = 0; i < (1 << p); i++) gln[i] = 0;
	PolyLn(g, gln, n);
	for (int i = 0; i < n; i++) gln[i] = ((i == 0) - gln[i] + f[i] + mod) % mod;
	for (int i = n; i < (1 << p); i++) gln[i] = 0;
	for (int i = 0; i < n; i++) Tmp[i] = g[i];
	for (int i = n; i < (1 << p); i++) Tmp[i] = 0;
	NTT(gln, rev[p], 1 << p, 1); NTT(Tmp, rev[p], 1 << p, 1);
	for (int i = 0; i < (1 << p); i++) Tmp[i] = (gln[i] * Tmp[i]) % mod;
	NTT(Tmp, rev[p], 1 << p, -1);
	for (int i = 0; i < (1 << p); i++) g[i] = (i < n) ? Tmp[i] : 0;
}
LL fln[N];
void PolyPower(LL *f, LL *g, int n, LL k1, LL k2) {
	int stpos = 0;
	for (int i = 0; i < n; i++) if (f[i] > 0) {stpos = i; break;}
	LL tInv = qpow(f[stpos], mod - 2);
	int p = 0; while ((1 << p) <= n - stpos + n - stpos) p++;
	for (int i = stpos; i < n; i++) fln[i - stpos] = f[i] * tInv;
	for (int i = n - stpos; i < (1 << p); i++) fln[i] = 0;
	PolyLn(fln, fln, n - stpos);
	k1 %= mod; for (int i = 0; i < (1 << p); i++) fln[i] = (k1 * fln[i]) % mod;
	PolyExp(fln, g, n - stpos);
	LL Shift = k1 * stpos; tInv = qpow(tInv, k2);
	for (int i = Shift; i < n; i++) g[i] = (g[i - Shift] * tInv) % mod;
	for (int i = 0; i < Shift; i++) g[i] = 0;
}
char ch[100005];
int main() {
	int n, m, p = 0;
	scanf("%d%s", &n, ch + 1);
	for (int i = 0; i < n; i++) scanf("%lld", A + i);
	LL k1 = 0, k2 = 0;
	int len = strlen(ch + 1);
	for (int i = 1; i <= len; i++) {
		k1 = (k1 * 10 + (ch[i] ^ 48)) % mod;
		k2 = (k2 * 10 + (ch[i] ^ 48)) % (mod - 1);
	}
    for (int k = 1; k < 22; k++) {
        for (int i = 1, End = (1 << k); i < End; i++) {
            rev[k][i] = (rev[k][i >> 1] >> 1) | ((i & 1) << (k - 1));
        }
    }
    PolyPower(A, B, n, k1, k2);
    for (int i = 0; i < n; i++) {
        printf("%lld%c", B[i], " \n"[i == n - 1]);
    }
	return 0;
}

加强版

题目链接

无特殊性质。

题目解答

需要先将 f(x)f(x) 变为 f(x)ak1xkf(x) \cdot a_k^{-1} \cdot x^{-k},其中 kk 为非 00 系数次项中的最低次数。变换后再进行普通版的操作,最后在变换回原来的。

注意特判。

代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1 << 22 | 5;
const LL mod = 998244353ll, inv2 = 499122177ll;
LL A[N], B[N]; // f, g
LL fA[N], fB[N], fAns[N];
int rev[23][N];
inline LL qpow(LL n, LL p, LL mod = 998244353ll) {
    LL ret = 1ll;
    while (p) {
        if (p & 1) (ret *= n) %= mod;
        (n *= n) %= mod; p >>= 1;
    }
    return ret;
}
void NTT(LL *F, int *Rev, int n, int Sign) { // Sign = 1: NTT , Sign = -1: INTT
	for (int i = 0; i < n; i++) {
		if (i < Rev[i]) swap(F[i], F[Rev[i]]);
	}
	for (int len = 2; len <= n; len <<= 1) {
		int mid = len >> 1;
        LL wl = qpow(3, (mod - 1) / len);
		for (int i = 0; i < n; i += len) {
			LL Ang = 1;
			for (int j = 0; j < mid; j++) {
				LL tx = F[i + j], ty = (F[i + j + mid] * Ang) % mod;
				F[i + j] = (tx + ty) % mod; F[i + j + mid] = (tx - ty + mod) % mod;
				Ang = (Ang * wl) % mod;
			}
		}
	}
    if (Sign == -1) {
        for (int i = 1; i < (n >> 1); i++) swap(F[i], F[n - i]);
        LL Invn = qpow(n, mod - 2);
        for (int i = 0; i < n; i++) {
            (F[i] *= Invn) %= mod;
        }
    }
}
LL Tmp[N];
void Polyinv(LL *f, LL *g, int n) {
    if (n == 1) {g[0] = qpow(f[0], mod - 2); return;}
    Polyinv(f, g, (n + 1) >> 1);
    int p = 0; while ((1 << p) <= n + n) p++;
    for (int i = 0; i < (1 << p); i++) Tmp[i] = (i < n) ? f[i] : 0;
    NTT(Tmp, rev[p], 1 << p, 1); NTT(g, rev[p], 1 << p, 1);
    for (int i = 0; i < (1 << p); i++) Tmp[i] = ((g[i] * ((2ll - Tmp[i] * g[i]) % mod)) % mod + mod) % mod;
    NTT(Tmp, rev[p], 1 << p, -1);
    for (int i = 0; i < (1 << p); i++) g[i] = (i < n) ? Tmp[i] : 0;
}
LL ginv[N];
void PolySqrt(LL *f, LL *g, int n) {
    if (n == 1) {g[0] = 1; return;}
    PolySqrt(f, g, (n + 1) >> 1);
    int p = 0; while ((1 << p) <= n + n) p++;
    for (int i = 0; i < (1 << p); i++) ginv[i] = 0;
    Polyinv(g, ginv, n);
    for (int i = 0; i < (1 << p); i++) Tmp[i] = (i < n) ? f[i] : 0;
    NTT(Tmp, rev[p], 1 << p, 1); NTT(g, rev[p], 1 << p, 1); NTT(ginv, rev[p], 1 << p, 1);
    for (int i = 0; i < (1 << p); i++) Tmp[i] = (inv2 * ((g[i] + ginv[i] * Tmp[i]) % mod)) % mod;
    NTT(Tmp, rev[p], 1 << p, -1);
    for (int i = 0; i < (1 << p); i++) g[i] = (i < n) ? Tmp[i] : 0;
}
LL finv[N];
void PolyLn(LL *f, LL *g, int n) {
    int p = 0; while ((1 << p) <= n + n) p++;
    for (int i = 0; i < (1 << p); i++) finv[i] = 0;
    Polyinv(f, finv, n);
    for (int i = 0; i < n - 1; i++) g[i] = (f[i + 1] * (i + 1)) % mod;
    for (int i = n - 1; i < (1 << p); i++) g[i] = 0;
    for (int i = 0; i < (1 << p); i++) Tmp[i] = g[i];
    NTT(finv, rev[p], 1 << p, 1); NTT(Tmp, rev[p], 1 << p, 1);
    for (int i = 0; i < (1 << p); i++) Tmp[i] = (finv[i] * Tmp[i]) % mod;
    NTT(Tmp, rev[p], 1 << p, -1);
    g[0] = 0; for (int i = 1; i < (1 << p); i++) g[i] = (i < n) ? ((Tmp[i - 1] * qpow(i, mod - 2)) % mod) : 0;
}
LL gln[N];
void PolyExp(LL *f, LL *g, int n) {
	if (n == 1) {g[0] = 1; return;}
	PolyExp(f, g, (n + 1) >> 1);
	int p = 0; while ((1 << p) <= n + n) p++;
	for (int i = 0; i < (1 << p); i++) gln[i] = 0;
	PolyLn(g, gln, n);
	for (int i = 0; i < n; i++) gln[i] = ((i == 0) - gln[i] + f[i] + mod) % mod;
	for (int i = n; i < (1 << p); i++) gln[i] = 0;
	for (int i = 0; i < n; i++) Tmp[i] = g[i];
	for (int i = n; i < (1 << p); i++) Tmp[i] = 0;
	NTT(gln, rev[p], 1 << p, 1); NTT(Tmp, rev[p], 1 << p, 1);
	for (int i = 0; i < (1 << p); i++) Tmp[i] = (gln[i] * Tmp[i]) % mod;
	NTT(Tmp, rev[p], 1 << p, -1);
	for (int i = 0; i < (1 << p); i++) g[i] = (i < n) ? Tmp[i] : 0;
}
LL fln[N];
void PolyPower(LL *f, LL *g, int n, LL k1, LL k2) {
	int stpos = 0;
	for (int i = 0; i < n; i++) if (f[i] > 0) {stpos = i; break;}
	LL tInv = qpow(f[stpos], mod - 2), Shift = k1 * (LL)stpos, tMul = qpow(f[stpos], k2);
	if (Shift >= n) {for (int i = 0; i < n; i++) g[i] = 0; return;}
	int p = 0; while ((1 << p) <= n - stpos + n - stpos) p++;
	for (int i = stpos; i < n; i++) fln[i - stpos] = (f[i] * tInv) % mod;
	PolyLn(fln, fln, n - stpos);
	for (int i = 0; i < (1 << p); i++) fln[i] = (k1 * fln[i]) % mod;
	PolyExp(fln, g, n - stpos);
	for (int i = n - 1; i >= Shift; i--) g[i] = (g[i - Shift] * tMul) % mod;
	for (int i = 0; i < Shift; i++) g[i] = 0;
}
char ch[100005];
int main() {
	int n, m, p = 0;
	scanf("%d%s", &n, ch + 1);
	for (int i = 0; i < n; i++) scanf("%lld", A + i);
	LL k1 = 0, k2 = 0;
	int len = strlen(ch + 1);
	for (int i = 1; i <= len; i++) {
		k1 = (k1 * 10 + (ch[i] ^ 48)) % mod;
		k2 = (k2 * 10 + (ch[i] ^ 48)) % (mod - 1);
	}
	if (A[0] == 0 && len > 6) {
		for (int i = 1; i <= n; i++) printf("0%c", " \n"[i == n]); return 0;
	}
	LL tk = 0;
	if (A[0] == 0 && len <= 6) {
		for (int i = 1; i <= len; i++) tk = (tk * 10 + (ch[i] ^ 48));
		for (int i = 0; i < n; i++) if (A[i] != 0) break;
		else if (i * tk >= n) {for (int i = 1; i <= n; i++) printf("0%c", " \n"[i == n]); return 0;}
	}
    for (int k = 1; k < 22; k++) {
        for (int i = 1, End = (1 << k); i < End; i++) {
            rev[k][i] = (rev[k][i >> 1] >> 1) | ((i & 1) << (k - 1));
        }
    }
    PolyPower(A, B, n, k1, k2);
    for (int i = 0; i < n; i++) {
        printf("%lld%c", B[i], " \n"[i == n - 1]);
    }
	return 0;
}